"components/vscode:/vscode.git/clone" did not exist on "c90e3dff7e774d5170a83823eb225214bcc9f9ab"
Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
//
// Based off of:
// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Specifically:
// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Referred to as upstream from in the comments
//
// The main optimization machete implements compared to upstream is to prepack
// the weight matrix to more closely match the shape of the wgmma instructions
// allowing for wider (ideally 128bit) shared memory loads. For subbyte types
// this is done by packing values from multiple wgmma loads (for a single
// thread) into a single 128bit load. This is very similar to layout used in
// Marlin, although specific to the wgmma instructions.
//
// Since the wgmma instructions only support sourcing from registers for the A
// operand, and we want to upconvert/decompress the weight values/elements
// before feeding them into the tensor cores in registers, we need the weight
// matrix to be A. To achieve this we compute the transpose of Y = XW^t as
// Y^t = W^tX^t. This is mostly done outside of this file in
// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the
// quantized/narrow type and has the prepacked layout despite the API being:
// B_prepacked = machete_prepack_B(B)
// Y = machete_mm(A, B_prepacked)
//
#pragma once
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/detail/dependent_false.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/detail/layout.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
#include "cutlass/trace.h"
#include "cutlass/detail/collective.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
namespace machete {
using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;
using namespace cutlass::gemm::collective;
using namespace cutlass::gemm::collective::detail;
template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
class ElementB_, class GmemLayoutB, int AlignmentB,
class ElementAccumulator_, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType>
struct MacheteCollectiveMma {
using Schedule = KernelScheduleType;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
"KernelSchedule must be one of the warp specialized policies");
public:
static constexpr bool ALayoutIsPrepacked = true;
// Prepacked block shape (N is M in the transposed problem)
using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK;
// Prepacked blocks per dim for a single MMA tile
using PPBlocksPerTile_MK = decltype(make_shape(
size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));
using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout;
static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0,
"M in PPBlockShape_MK must evenly divide M TileShape_MNK");
static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0,
"K in PPBlockShape_MK must evenly divide K TileShape_MNK");
using ArchTag = arch::Sm90;
using TileShape = TileShape_MNK;
using ClusterShape = ClusterShape_MNK;
using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>;
using StrideA = TagToStrideA_t<layout::RowMajor>;
using ElementB = ElementB_;
using StrideB = TagToStrideB_t<GmemLayoutB>;
using ElementAccumulator = ElementAccumulator_;
using ElementMma = ElementB;
using ElementATuple =
cute::conditional_t<!cute::is_tuple<ElementATuple_>::value,
cute::tuple<ElementA>, ElementATuple_>;
static constexpr cute::GMMA::Major GmmaMajorA =
gmma_rs_tag_to_major_A<layout::RowMajor>();
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<GmemLayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
TileShape_MNK, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
private:
//
// the setup section (until "section setup end") contains a combination of
// modified code from (used as a starting point):
// `cutlass/gemm/collective/builders/sm90_gmma_builder.inl`
// `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp`
// (upstream)
//
// however in-order to simplify the code we combine a lot of the logic from
// `CollectiveMma` and `CollectiveBuilder` into this class, this also makes
// sense given that we have flexibility on layouts here. We also simplify the
// code by only supporting scales and zeros for A (in the transposed problem,
// B from an API perspective), also since we force A to be the narrow type
// (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in
// the upstream also simplifying the code. This section includes new logic
// (compared ustream) for handling the prepacked-A layouts (in the transposed
// problem, B from an API perspective)
//
using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>;
using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>;
static constexpr bool IsANarrow = cutlass::sizeof_bits<ElementA>::value <
cutlass::sizeof_bits<ElementB>::value;
static_assert(IsANarrow,
"A must be the narrow one since its the one that flows through "
"registers.");
public:
static constexpr int PipelineStages =
compute_stage_count_or_override_single_affine_transformed_input<
sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale,
ElementZero, TileShape_MNK>(StageCountType{});
struct DispatchPolicy {
constexpr static int Stages = PipelineStages;
using ClusterShape = ClusterShape_MNK;
using Schedule = KernelScheduleType;
};
using GmemTiledCopyA =
decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB =
decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
// ((T, V), (BlocksM, BlocksK), pipe) -> offset
using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
using SmemLayoutAtomARowMajor =
decltype(rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomScale = Layout<
Shape<decltype(cute::shape<0>(SmemLayoutAtomARowMajor{})), cute::Int<1>>>;
using SmemLayoutAtomB =
decltype(rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemCopyAtomA = Copy_Atom<cute::DefaultCopy, ElementA>;
using SmemCopyAtomB = void;
//
// Validity checks
//
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
private:
enum class ConversionMode {
DirectConvert,
ConvertAndScale,
ConvertAndScaleWithZero
};
public:
//
// Type Aliases
//
using KernelSchedule = KernelScheduleType;
// For cases where we can't have a void type, we can use this to allow the
// code to compile when the scale / zero is void.
using NonVoidElementScale =
cute::conditional_t<cute::is_void_v<ElementScale>, float, ElementScale>;
using NonVoidElementZero =
cute::conditional_t<cute::is_void_v<ElementZero>, float, ElementZero>;
// These are always MN major
using StrideScale = cute::Stride<cute::Int<1>, int64_t, int64_t>;
// For cases where we can't have a void scale, we can use this to allow the
// code to compile when the scale is void.
using NonVoidStrideScale =
cute::conditional_t<cute::is_void_v<StrideScale>,
cute::Stride<_1, int64_t, int64_t>, StrideScale>;
static_assert((cutlass::gemm::detail::is_k_major<StrideA>()),
"The transformed matrix (A) must be K-major.");
static_assert((sizeof(ElementB) == 2) ||
(cutlass::gemm::detail::is_k_major<StrideA>() &&
cutlass::gemm::detail::is_k_major<StrideB>()),
"The unscaled element (matrix B) must be 2 bytes OR both "
"inputs must be K-major");
static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
"Scale must be MN major [Col Major if A is scaled, Row Major "
"if B is scaled].");
static_assert(std::is_same_v<typename TiledMma::ValTypeC, ElementAccumulator>,
"TiledMma::ValTypeC must be the same as ElementAccumulator.");
using GmemTiledCopyScale = cute::SM90_TMA_LOAD;
using SmemCopyAtomScale = Copy_Atom<cute::DefaultCopy, NonVoidElementScale>;
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any
// rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA =
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB =
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementB>>>;
using TransformA = cute::identity;
using TransformB = cute::identity;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<InternalElementA> < 8;
using TmaElementA =
cute::conditional_t<IsSubbyteA, uint8_t, InternalElementA>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
shape<1>(SmemLayoutAtomScale{})));
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomScale{}) == 2,
"SmemLayoutAtomScale must be rank 2");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must equal the tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutACopy = decltype(tile_to_shape(
SmemLayoutAtomARowMajor{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
// It is assumed that the scales and zero-points share the same smem layout
using SmemLayoutScale = decltype(tile_to_shape(
SmemLayoutAtomScale{},
make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}),
Int<PipelineStages>{})));
// If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major
// only (e.g. tf32, fp32, fp8, int8).
static constexpr bool IsLayoutAmnBmn =
cute::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>,
layout::ColumnMajor> &&
cute::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>,
layout::RowMajor>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(not cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source A from rmem and B operand from smem_desc "
"for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
using GmmaSmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
// These two restrictions are related, so we place the assertions together.
// To relax them, we need to handle loading more than 1 row of scales for
// every main loop iteration. We must also handle updating the pipeline
// transaction bytes on the fly. NOTE: Deleting this assertion without
// required changes will cause the code to hang.
static_assert(size<1>(SmemLayoutAtomScale{}) == 1,
"size<1>(SmemLayoutAtomScale) must be 1.");
private:
static constexpr ConversionMode get_conversion_mode() {
if constexpr (cute::is_void_v<ElementScale>) {
return ConversionMode::DirectConvert;
} else if constexpr (cute::is_void_v<ElementZero>) {
return ConversionMode::ConvertAndScale;
} else {
return ConversionMode::ConvertAndScaleWithZero;
}
}
static constexpr ConversionMode KernelConversionMode = get_conversion_mode();
static constexpr bool ModeHasScales =
KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
// Same as upstream, should be kept the same when possible
static constexpr auto elements_per_smem_scale() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Type not handled in scale smem allocation.");
}
}
// Same as upstream, should be kept the same when possible
static constexpr auto elements_per_smem_zero() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale) {
return 0;
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Type not handled in scale smem allocation.");
}
}
// Same as upstream, should be kept the same when possible, not formatte for
// easier comparison
// clang-format off
// These methods use some the public members of the class. For that reason, we define them after the public section.
static constexpr uint32_t
compute_tma_transaction_bytes_mk() {
constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementA>));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return baseline_bytes;
}
else if constexpr (ModeHasScales) {
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return baseline_bytes + scale_tx_bytes;
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// Scale and zero share smem layout
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
return baseline_bytes + scale_tx_bytes + zero_tx_bytes;
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
static constexpr uint32_t
compute_tma_transaction_bytes_nk() {
return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementB>));
}
// clang-format on
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
make_shape(int32_t(0), int32_t(0), int32_t(0)))));
using ATensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
shape(GmemLayoutA::TVbNbKL_to_offset(
make_shape(int32_t(0), int32_t(0), int32_t(0)))),
PrepackedStrideA{}));
using BTensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementB const*>(nullptr)),
repeat_like(StrideB{}, int32_t(0)), StrideB{}));
using ScaleTensor = decltype(make_tensor(
get_logical_ptr(static_cast<NonVoidElementScale const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
using ZeroTensor = decltype(make_tensor(
get_logical_ptr(static_cast<NonVoidElementZero const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
return make_tma_copy<TmaElementA>(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_scale(
ScaleTensor tensor_scale = ScaleTensor{}) {
return make_tma_copy(GmemTiledCopyScale{}, tensor_scale,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_zero(
ZeroTensor tensor_zero = ZeroTensor{}) {
return make_tma_copy(GmemTiledCopyScale{}, tensor_zero,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) {
return make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
}
public:
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic
// clang-format off
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
// Just pick the max alignment of A and B since it is required to be at least 128B
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment");
struct SharedStorage
{
static constexpr int scale_elements = elements_per_smem_scale();
static constexpr int zero_elements = elements_per_smem_zero();
struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentA, SmemAlignmentB)> {
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::ArrayEngine<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A = nullptr;
StrideA dA{};
ElementB const* ptr_B = nullptr;
StrideB dB{};
ElementScale const* ptr_S = nullptr;
NonVoidStrideScale dS{};
int group_size = 0;
ElementZero const* ptr_Z = nullptr;
uint32_t mma_promotion_interval = 4;
};
// clang-format on
//
// section setup end
//
// Similar (but not idendtical) to upstream, should be kept the same when
// possible
// compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to
// define the TMA types
// Device side kernel params
struct Params {
public:
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy_A());
using TMA_Scale = decltype(make_tma_copy_scale());
using TMA_Zero = decltype(make_tma_copy_zero());
using TMA_B = decltype(make_tma_copy_B());
// required by outer loop: i.e.
// cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Scale tma_load_scale;
TMA_Zero tma_load_zero;
int64_t scale_k;
int group_size;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
};
//
// Methods
//
// Similar (but not idendtical) to upstream, should be kept the same when
// possible
// compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here
// to handle the prepacked layout
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args,
void* workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) {
return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride));
};
typename Params::TMA_A tma_load_a;
typename Params::TMA_B tma_load_b;
typename Params::TMA_Scale tma_load_scale;
typename Params::TMA_Zero tma_load_zero;
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
tma_load_a = make_tma_copy_A(
make_logical_tensor(ptr_A, shape(layout), stride(layout)));
tma_load_b = make_tma_copy_B(
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
if constexpr (ModeHasScales) {
tma_load_scale = make_tma_copy_scale(make_logical_tensor(
args.ptr_S, make_shape(M, args.group_size, L), args.dS));
}
if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
tma_load_zero = make_tma_copy_zero(make_logical_tensor(
args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
} else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
return {tma_load_a, tma_load_b, tma_load_scale,
tma_load_zero, scale_k, args.group_size};
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments.");
}
}
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `SwapAB ? N : M -> M` since we dont support SwapAB
// clang-format off
template<class ProblemShape>
static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
implementable = implementable && (args.ptr_S == nullptr);
implementable = implementable && (args.ptr_Z == nullptr);
}
else if constexpr (ModeHasScales) {
const int scale_mn = M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.group_size != 0;
implementable = implementable && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.ptr_Z != nullptr);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk();
static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk();
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor());
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA prefetch.");
}
}
// clang-format off
// Modified from upstream, should be kept close to that when possible
// the main difference is special handling for the prepacked A layout
//
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the
// contract Returned tuple must contain at least two elements, with the first
// two elements being: gA_mkl - The tma tensor, A after a local tile so it
// has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local
// tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be
// specified as needed by this collective.
// NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the
// values within a prepacked block.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL,
Params const& mainloop_params) const {
using X = Underscore;
auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL),
K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL);
// (TILE_V,TILE_B,m,k,l)
auto make_gA_mkl = [&]() {
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
return local_tile(mA_mkl,
make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
make_coord(0, make_coord(_, _)));
};
// (TILE_N,TILE_K,n,k,l)
auto make_gB_nkl = [&]() {
Tensor mB_nkl =
mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L));
return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{});
};
// (TILE_M,TILE_Scale_K,m,scale_k,l)
auto make_gS_mkl = [&]() {
auto scale_k = mainloop_params.scale_k;
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(
make_shape(M, scale_k, L));
return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _));
};
// (TILE_M,TILE_Scale_K,m,scale_k,l)
auto make_gZ_mkl = [&]() {
auto scale_k = mainloop_params.scale_k;
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(
make_shape(M, scale_k, L));
return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _));
};
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl());
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScale) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl());
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(),
make_gZ_mkl());
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in load_init.");
}
}
// Similar to upstream, should be kept close to that when possible
// the main difference is in the layout comments
// clang-format off
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
/// This overload gets triggered when we have scales.
template <
class... Ts,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<Ts...> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs");
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs");
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A, B and Scales
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_s = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord);
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do.
}
else if constexpr (ModeHasScales) {
auto tSgS = get<0>(extra_input_partitions);
auto tSsS = get<1>(extra_input_partitions);
// Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes
// on the fly.
// We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K
// is a multiple of the threadblock tile K
const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K.
copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
auto tZgZ = get<2>(extra_input_partitions);
auto tZsZ = get<3>(extra_input_partitions);
copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage));
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
// clang-format off
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
// clang-format on
// Modified from upstream, should be kept close to that when possible
// the main differences are handling the prepacked A layout, and separating
// the loading of A from upcoverting A
//
// Perform a collective-scoped matrix multiply-accumulate
// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read, FrgTensorC& accum,
int k_tile_count, int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for "
"RF sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
// Obtain warp index
int warp_idx = canonical_warp_idx_sync();
[[maybe_unused]] int warp_group_thread_idx = thread_idx % 128;
// ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset
auto constexpr smem_A = SmemLayoutA{};
// convert:
// ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset
// to:
// (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset
// which can be thought of as:
// (T, MMA, (MMA_M, MMA_K), pipe) -> offset
auto constexpr smem_A_mma_ =
make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A),
zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A));
// flatten to:
// (T, MMA, MMA_M, MMA_K, pipe) -> offset
auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _);
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()),
smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate fragments and descriptors
Tensor tCrA_load = make_tensor<ElementA>(
tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K)
Tensor tCrA_mma = make_fragment_like<ElementMma>(tCrA_load);
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
static constexpr int A_CPY_VEC =
decltype(max_common_vector(tCsA, tCrA_load)){};
static constexpr int COVERSION_WIDTH =
std::min(A_CPY_VEC, int(size<0>(tCrA_mma)));
auto load_A_to_registers = [&](int read_stage) {
copy(create_auto_vectorizing_copy<ElementA, decltype(A_CPY_VEC)>(),
tCsA(_, _, _, read_stage), tCrA_load(_, _, _));
};
// Partition of thread -> shared and thread -> RF
auto partitioned_extra_info =
partition_extra_mma_info(thread_mma, shared_tensors);
auto copy_partitions_extra_info = retile_extra_mma_info(
tiled_mma, partitioned_extra_info, warp_group_thread_idx);
CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
auto convert_A = [&, a_vec = Int<COVERSION_WIDTH>{}](int k_block,
int read_stage) {
load_extra_info_to_registers(partitioned_extra_info,
copy_partitions_extra_info, k_block,
read_stage);
transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info,
k_block);
};
// We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read;
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum);
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
ConsumerToken barrier_token = {BarrierStatus::WaitAgain};
// first k tile
{
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
// copy smem->rmem for A operand
load_A_to_registers(read_stage);
convert_A(0, read_stage);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
if (k_block < K_BLOCK_MAX - 1) {
convert_A(k_block + 1, smem_pipe_read.index());
}
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
}
--k_tile_count;
if (k_tile_count > 0) {
// Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to
// overwrite the A registers for the first mma.
warpgroup_wait<K_BLOCK_MAX - 1>();
pipeline.consumer_wait(smem_pipe_read, barrier_token);
load_A_to_registers(smem_pipe_read.index());
convert_A(0, smem_pipe_read.index());
}
}
if (k_tile_count == 0) {
return;
}
warpgroup_fence_operand(accum);
// Mainloop GMMAs
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 1; --k_tile_count) {
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
warpgroup_fence_operand(accum);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>();
if (k_block == K_BLOCK_MAX - 1) {
// We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
// so we can release prior barrier
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
// on it
++smem_pipe_release;
}
if (k_block == 0) {
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
}
if (k_block == K_BLOCK_MAX - 1) {
pipeline.consumer_wait(smem_pipe_read, barrier_token);
load_A_to_registers(smem_pipe_read.index());
convert_A(0, smem_pipe_read.index());
} else {
convert_A(k_block + 1, read_stage);
}
}
warpgroup_fence_operand(accum);
}
warpgroup_fence_operand(accum);
{
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>();
if (k_block == K_BLOCK_MAX - 1) {
// release prior barrier
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
// on it
++smem_pipe_release;
}
if (k_block < K_BLOCK_MAX - 1) {
convert_A(k_block + 1, read_stage);
}
}
}
warpgroup_fence_operand(accum);
}
// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = 1;
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on
// it
++smem_pipe_release;
}
}
private:
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Utilities for any additional inputs inside of the TMA load
template <class... Ts>
CUTLASS_DEVICE
auto partition_extra_tma_inputs(
Params const& mainloop_params,
cute::tuple<Ts...> const& load_inputs,
TensorStorage& shared_tensors,
uint2 const& cluster_local_block_id,
int const m_coord,
int const l_coord) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gS_mkl = get<2>(load_inputs);
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tSgS, tSsS);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gZ_mkl = get<3>(load_inputs);
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
// clang-format off
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
template <class ThreadMma>
CUTLASS_DEVICE
auto partition_extra_mma_info(
ThreadMma const& mma_thread_slice,
TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape());
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tCsS, tCrS);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape());
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
// clang-format on
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Returns the tiled copy and copy views for the extra inputs.
template <class TiledMma, class... Ts>
CUTLASS_DEVICE
auto retile_extra_mma_info(
TiledMma const& tiled_mma,
cute::tuple<Ts...>& partitioned_extra_info,
int const warp_group_thread_idx) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
// clang-format on
// Similar to `copy_A_and_extra_info` upstream, should be kept the same when
// possible
// the main differences this only loads the extra info into registers and
// not A (since we now preload more of A in the main pipeline)
// Load scales and zeros into registers if required
template <class... Ts, class... Us>
CUTLASS_DEVICE void load_extra_info_to_registers(
cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views, int k_block,
int read_stage) {
if (k_block == 0) {
// We are starting a new k-tile so copy the scale
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage),
tCrS_copy_view(_, _, k_block));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage),
tCrZ_copy_view(_, _, k_block));
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in A -> RF path.");
}
}
}
// Similar to upstream, should be kept the same when possible.
// the main differences are that `convert_tensor` supports interleaved
// layouts and bfloat16 has been optimized. `transform_internal_A` has also
// been inlined for code simplicity.
// Utilities to transform A.
template <class TCrA_load, int VectorWidthA, class TCrA_mma, class... Ts>
CUTLASS_DEVICE void transform_A_kblock(
TCrA_load const& tCrA_load, cute::Int<VectorWidthA> vec_A,
TCrA_mma& tCrA_mma, cute::tuple<Ts...> const& partitioned_extra_info,
int const k_block) {
auto in = tCrA_load(_, _, k_block);
auto out = tCrA_mma(_, _, k_block);
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
convert_tensor<IlvdBlkLayout>(in, out, vec_A);
} else if constexpr (ModeHasScales) {
auto tCrS = cute::get<1>(partitioned_extra_info);
auto converted_inputs =
make_fragment_like<ElementScale>(tCrA_mma)(_, _, k_block);
auto scales = tCrS(_, _, 0);
// First, we upcast the inputs to the scale type
convert_tensor<IlvdBlkLayout>(in, converted_inputs, vec_A);
// Apply scales and broadcast across inputs, store in converted_inputs
// We need to cast to nv_bfloat16 for the multiply since
// `cutlass::bfloat16_t` has an overloaded operator* that upconverts to
// float, which nvcc will not optimize to using vectorized fma
// instructions (i.e. hfma.bf16_v2)
if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
cute::transform(
recast<nv_bfloat16>(converted_inputs), recast<nv_bfloat16>(scales),
recast<nv_bfloat16>(converted_inputs), cute::multiplies{});
} else {
cute::transform(converted_inputs, scales, converted_inputs,
cute::multiplies{});
}
// Apply zeros if required
if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
auto tCrZ = cute::get<3>(partitioned_extra_info);
auto converted_zeros = make_fragment_like<ElementScale>(tCrZ)(_, _, 0);
convert_tensor<void>(tCrZ(_, _, 0), converted_zeros);
if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
cute::transform(recast<nv_bfloat16>(converted_inputs),
recast<nv_bfloat16>(converted_zeros),
recast<nv_bfloat16>(converted_inputs), cute::plus{});
} else {
cute::transform(converted_inputs, converted_zeros, converted_inputs,
cute::plus{});
}
}
// Finally, we convert the scaled inputs to the mma type.
convert_tensor<void>(converted_inputs, out);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"No A data is loaded.");
}
}
// Modified from upstream, should be kept the same when possible
// the main differences is that this version supports interleaved converts
// Utilities for transforming the A operand prior to issuing tensorcore math.
template <typename IlvdBlkLayout, class EngineIn, class EngineOut,
class TensorLayout,
int ConversionVectorWidth = cosize_v<TensorLayout>>
CUTLASS_DEVICE void convert_tensor(
Tensor<EngineIn, TensorLayout> const& in,
Tensor<EngineOut, TensorLayout>& out,
cute::Int<ConversionVectorWidth> width = {}) {
// This is an element-wise conversion where we expect both tensors to have
// the same layout. As a result, we can cast as a cutlass array to use the
// fast numeric converters without worrying about indexing into the layout.
constexpr int N = cosize_v<TensorLayout>;
// The inputs must be backed by registers & be statically sized.
static_assert(is_rmem<EngineIn>::value,
"Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value,
"Output tensor for A conversion must come from registers");
static_assert(is_static_v<TensorLayout>,
"Tensor layout for the conversion must be static");
static_assert(cosize_v<TensorLayout> == size(TensorLayout{}),
"Cosize and size of the layout must be equal.");
static_assert(
N % ConversionVectorWidth == 0,
"Conversion vector width must divide cosize of the tensor layout.");
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
using SrcArray = cutlass::Array<SrcType, ConversionVectorWidth>;
using DstArray = cutlass::Array<DstType, ConversionVectorWidth>;
constexpr cutlass::FloatRoundStyle RoundStyle =
cutlass::FloatRoundStyle::round_to_nearest;
using Converter = cutlass::InterleavedNumericArrayConverter<
IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>;
constexpr int NumIterations = N / ConversionVectorWidth;
for (int ii = 0; ii < NumIterations; ++ii) {
SrcArray const* src_array_ptr =
reinterpret_cast<SrcArray const*>(raw_pointer_cast(in.data())) + ii;
DstArray* dst_array_ptr =
reinterpret_cast<DstArray*>(raw_pointer_cast(out.data())) + ii;
*dst_array_ptr = Converter::convert(*src_array_ptr);
}
}
};
} // namespace machete
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, we want to upconvert/decompress the quantized operand in
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is an scale_KxN matrix
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete
#pragma once
#include <torch/all.h>
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
struct PyTorchArguments {
torch::Tensor const& A;
torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales;
c10::optional<torch::Tensor> const& zeros;
c10::optional<int64_t> group_size;
c10::optional<torch::Tensor> const& C;
c10::optional<double> alpha;
c10::optional<double> beta;
c10::optional<std::string> schedule;
};
template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0);
int N = args.B.size(1);
int K = args.A.size(1);
// Allocate output
torch::Tensor D =
torch::empty({M, N}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<EleD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(
workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));
MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);
return D;
};
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete
\ No newline at end of file
#pragma once
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
typename BTiledOutTensor>
static __global__ void prepack_B_kernel(BInTensor B_in,
BTiledOutTensor B_tiled_out) {
auto tB_in = local_tile(B_in, TileShapeNKL{},
make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
auto tB_out = B_tiled_out(make_coord(_, _),
make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
Layout<Shape<_4, _32>, Stride<_32, _1>>{},
Layout<Shape<_1, _2>>{});
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tB_in);
Tensor thr_tile_D = thr_copy.partition_D(tB_out);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream,
typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout,
typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
auto B_tiled_out =
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
}
}; // namespace machete
\ No newline at end of file
#pragma once
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
auto device = B.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
// elements per storage item for B
auto eles_per_storage =
(B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
// torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
auto Bt_packed = B.t();
TORCH_CHECK(
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
// clang-format off
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
template <typename ElementA, typename ElementB, typename ElementD,
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
static torch::Tensor dispatch(torch::Tensor B);
};
}; // namespace machete
\ No newline at end of file
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interlaeved
// inorder to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementMma = MmaType;
// Only use interleaved layouts for subbyte weights, prmt instructions makes
// non-interleaved layouts for 8bit+ weights efficient enough we don't need
// iterleaved layouts
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<sizeof_bits_v<ElementB> <= 4,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementA>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. we amount of data associated with each thread within a
// prepacked block is a multiple of 128bits, when using a cooperative sechdule
// we have 256 threads working a single block at a time, this means each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
// for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel,
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B), must be the lhs operand so the flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
// Return iterleaved layout
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
// Return iterleaved layout
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (M,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (M,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
// BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete
\ No newline at end of file
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
namespace machete {
using namespace vllm;
//
// Utils (type dispatching)
//
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
if (type == vllm::kU4) {
return fn(cutlass::uint4b_t{});
} else if (type == vllm::kU8) {
return fn(cutlass::uint8_t{});
} else if (type == vllm::kU4B8) {
return fn(cutlass::vllm_uint4b8_t{});
} else if (type == vllm::kU8B128) {
return fn(cutlass::vllm_uint8b128_t{});
} else {
TORCH_CHECK(false, "Unsupported type ", type.str());
}
}
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
.group_size = group_size,
.C = C,
.alpha = alpha,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
}; // namespace machete
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_compat.h"
namespace vllm {
namespace detail {
template <typename T>
__inline__ __device__ T _max(T a, T b) {
return max(a, b);
}
template <typename T>
__inline__ __device__ T _sum(T a, T b) {
return a + b;
}
} // namespace detail
template <typename T>
using ReduceFnType = T (*)(T, T);
// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduce<T>(val, fn);
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
}
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}
} // namespace vllm
......@@ -198,6 +198,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules", &machete::supported_schedules);
ops.def(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
......@@ -210,6 +225,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("awq_marlin_repack", &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
// Dequantization for GGML.
ops.def("ggml_dequantize", &ggml_dequantize);
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML.
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML.
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
......@@ -219,13 +246,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
......
......@@ -3,9 +3,10 @@ sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation
pydantic
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
......
......@@ -20,6 +20,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Roblox
- RunPod
- Sequoia Capital
- Skywork AI
- Trainy
- UC Berkeley
- UC San Diego
......
......@@ -97,13 +97,13 @@ def setup(app):
# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
"aiohttp",
"compressed_tensors",
"cpuinfo",
"torch",
"transformers",
"psutil",
"prometheus_client",
"sentencepiece",
"vllm.cuda_utils",
"vllm._C",
"PIL",
"numpy",
......@@ -112,6 +112,10 @@ autodoc_mock_imports = [
"tensorizer",
"pynvml",
"outlines",
"librosa",
"soundfile",
"gguf",
"lark",
]
for mock_target in autodoc_mock_imports:
......
......@@ -17,4 +17,4 @@ Input Processing Pipeline
6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model.
......@@ -15,6 +15,9 @@ by following :ref:`this guide <adding_multimodal_plugin>`.
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
..
TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
Guides
++++++
......
Profiling vLLM
=================================
We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.
.. warning::
Only enable profiling in a development environment.
Traces can be visualized using https://ui.perfetto.dev/.
.. tip::
Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly.
Example commands:
OpenAI Server:
.. code-block:: bash
VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
benchmark_serving.py:
.. code-block:: bash
python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
\ No newline at end of file
......@@ -30,24 +30,59 @@ Here are some common issues that can cause hangs:
.. code-block:: python
# Test PyTorch NCCL
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
data = torch.FloatTensor([1,] * 128).to(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)
data = torch.FloatTensor([1,] * 128).to("cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
value = data.mean().item()
world_size = dist.get_world_size()
assert value == world_size, f"Expected {world_size}, got {value}"
print("PyTorch NCCL is successful!")
# Test PyTorch GLOO
gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
cpu_data = torch.FloatTensor([1,] * 128)
dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
value = cpu_data.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"
print("sanity check is successful!")
print("PyTorch GLOO is successful!")
# Test vLLM NCCL, with cuda graph
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
pynccl.disabled = False
s = torch.cuda.Stream()
with torch.cuda.stream(s):
data.fill_(1)
pynccl.all_reduce(data, stream=s)
value = data.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"
print("vLLM NCCL is successful!")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g, stream=s):
pynccl.all_reduce(data, stream=torch.cuda.current_stream())
data.fill_(1)
g.replay()
torch.cuda.current_stream().synchronize()
value = data.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"
print("vLLM NCCL with cuda graph is successful!")
dist.destroy_process_group(gloo_group)
dist.destroy_process_group()
.. tip::
......
......@@ -68,6 +68,16 @@ You can also build and install vLLM from source:
$ cd vllm
$ pip install -e . # This may take 5-10 minutes.
.. note::
vLLM can fully run only on Linux, but you can still build it on other systems (for example, macOS). This build is only for development purposes, allowing for imports and a more convenient dev environment. The binaries will not be compiled and not work on non-Linux systems. You can create such a build with the following commands:
.. code-block:: console
$ export VLLM_TARGET_DEVICE=empty
$ pip install -e .
.. tip::
Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache <https://github.com/ccache/ccache>`_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster.
......
......@@ -57,7 +57,7 @@ Install from source
.. code-block:: console
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
.. _openvino_backend_performance_tips:
......@@ -70,7 +70,7 @@ vLLM OpenVINO backend uses the following environment variables to control behavi
- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform.
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off.
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`
To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``)
......@@ -91,5 +91,3 @@ Limitations
- Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration.
- Tensor and pipeline parallelism are not currently enabled in vLLM integration.
- Speculative sampling is not tested within vLLM integration.
......@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
Requirements
------------
* Google Cloud TPU VM (single host)
* Google Cloud TPU VM (single & multi host)
* TPU versions: v5e, v5p, v4
* Python: 3.10
......@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240726"
$ export DATE="+20240808"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
......@@ -65,7 +65,7 @@ First, install the dependencies:
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install packaging aiohttp
$ pip install -r requirements-tpu.txt
Next, build vLLM from source. This will only take a few seconds:
......
......@@ -31,8 +31,10 @@ vLLM is fast with:
* Efficient management of attention key and value memory with **PagedAttention**
* Continuous batching of incoming requests
* Fast model execution with CUDA/HIP graph
* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_, FP8 KV Cache
* Optimized CUDA kernels
* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, INT4, INT8, and FP8
* Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
* Speculative decoding
* Chunked prefill
vLLM is flexible and easy to use with:
......@@ -41,9 +43,9 @@ vLLM is flexible and easy to use with:
* Tensor parallelism and pipeline parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
* Support NVIDIA GPUs and AMD GPUs
* (Experimental) Prefix caching support
* (Experimental) Multi-lora support
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
* Prefix caching support
* Multi-lora support
For more information, check out the following:
......@@ -53,7 +55,6 @@ For more information, check out the following:
* :ref:`vLLM Meetups <meetups>`.
Documentation
-------------
......@@ -106,6 +107,7 @@ Documentation
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/int8
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
......@@ -134,6 +136,7 @@ Documentation
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile
dev/profiling/profiling_index
.. toctree::
:maxdepth: 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment