Unverified Commit 222ce6f1 authored by Yineng Zhang, committed by GitHub

add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)


Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
parent 468d23cf
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
template <class /* ElementCompute */> class Activation, bool SwapAB = false>
struct CollectiveMmaGated
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
using InternalElementAux = cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
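// For example, ElementA == float is re-typed as tfloat32_t (TMA converts f32 -> tf32 on the
// way into smem), while ElementA == half_t is re-typed as uint16_t so the TMA copy is purely
// bitwise.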
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
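// Note on the expected operand layout (see to_underlying_arguments() below): the auxiliary
// "gate" operand is read starting at args.ptr_B + N*K*L elements when SwapAB == false
// (respectively args.ptr_A + M*K*L when SwapAB == true), i.e. the second matrix of the gated
// GEMM is assumed to be packed contiguously after B (or A) and to share its stride. A
// host-side sketch with hypothetical pointers:
//
//   // ptr_B points at [B ; B_gate] stored back-to-back, each N x K x L with stride dB.
//   Arguments args{ptr_A, dA, ptr_B, dB, /*scale_d0=*/1.0f, /*scale_d1=*/1.0f};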
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<InternalElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
else
{
auto ptr_Aux = reinterpret_cast<InternalElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
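// Worked example (assumed configuration, for illustration only): with a 128x128x64 tile and
// 16-bit A/B/Aux elements, each stage transfers
//   A:   128 * 64 * 16 / 8 = 16384 bytes
//   B:   128 * 64 * 16 / 8 = 16384 bytes
//   Aux: 128 * 64 * 16 / 8 = 16384 bytes
// so TmaTransactionBytes = 49152 bytes per pipeline stage.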
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the following contract:
/// the returned tuple must contain at least three elements, with the first three being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// gAux_xkl - The tma tensor, A or B after a local tile so it has shape (BLK_M/BLK_N,BLK_K,m/n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
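// Worked example (assuming a 2x2x1 cluster; Layout<ClusterShape> defaults to column-major,
// i.e. block_id = m + 2*n): a block with block_rank_in_cluster == 1 has
// cluster_local_block_id = {x = 1, y = 0}. Its A tile is multicast along N, so mcast_mask_a
// sets bits block_layout(1, 0) = 1 and block_layout(1, 1) = 3 (mask 0b1010); its B tile is
// multicast along M, so mcast_mask_b sets bits block_layout(0, 0) = 0 and
// block_layout(1, 0) = 1 (mask 0b0011). mcast_mask_aux then reuses whichever mask matches the
// operand that the aux matrix shadows (A when SwapAB, otherwise B).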
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
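// With K_PIPE_MMAS == 1, the prologue below issues one GMMA batch before entering the steady
// state; warpgroup_wait<K_PIPE_MMAS>() in the main loop then leaves at most one batch
// outstanding, so each smem stage is released back to the producer one iteration after its
// GMMAs were issued.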
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
else
{
auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA
* instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the following contract:
/// the returned tuple must contain at least three elements, with the first three being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// gAux_xkl - The tma tensor, A or B after a local tile so it has shape (BLK_M/BLK_N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
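// A note on the FP8 accumulation wrappers above (see cutlass/gemm/collective/fp8_accumulation.hpp
// for the authoritative logic): GMMAs accumulate into a temporary fragment returned by
// accumulation0()/accumulation1(), and promote_if_needed() folds that fragment into the real
// accumulators (accum0/accum1) roughly every mma_promotion_interval MMAs, limiting error growth
// from the reduced-precision FP8 tensor-core accumulation. prepare_if_needed() reports when a
// fresh temporary accumulation begins, which is why ScaleOut::Zero is re-applied at that point.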
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation0.promote_residue_if_needed();
accumulation1.promote_residue_if_needed();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
// #include <limits>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
that feature at the moment.
*/
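/*
  Illustrative host-side usage sketch; `MyGemmKernel`, `make_args()` and `stream` are hypothetical
  placeholders supplied by the caller, and the calls mirror the public interface of
  GemmUniversalBaseCompat defined just below:

    using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>;
    typename Gemm::Arguments args = make_args();
    if (Gemm::can_implement(args) == cutlass::Status::kSuccess)
    {
        size_t workspace_bytes = Gemm::get_workspace_size(args);
        void* workspace = nullptr;
        if (workspace_bytes)
        {
            cudaMalloc(&workspace, workspace_bytes);
        }
        Gemm gemm;
        if (gemm.initialize(args, workspace, stream) == cutlass::Status::kSuccess)
        {
            gemm.run(stream);
        }
    }
*/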
template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
public:
/// Constructs the GEMM.
GemmUniversalBaseCompat() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
{
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Split-K parallel always requires a temporary workspace
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
}
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
{
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10))
{
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess)
{
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
}
else
{
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
if (smem_capacity < 0)
{
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess)
{
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes)
{
if (!workspace)
{
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm)
{
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
int64_t* splitk_buffer_offsets)
{
// in_tensor: [problem_idx, k_partition, hidden_size]
// Note that different requests of in_tensor might have different hidden_size (=m*n)
// so we need to use splitk_buffer_offsets.
// out_tensor: problem_idx * [hidden_size]
int const problem_idx = blockIdx.y;
GemmCoord problem = problem_sizes[problem_idx];
int const hidden_size = problem.m() * problem.n();
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
T_OUT* out_tensor_ = out_tensor[problem_idx];
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
{
float sum = 0.0f;
for (int k_idx = 0; k_idx < splitk; k_idx++)
{
sum += (float) in_tensor_[k_idx * hidden_size + i];
}
out_tensor_[i] = (T_OUT) (sum);
}
}
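/*
  Worked indexing example with hypothetical sizes: for two problems with hidden_size 6 and 4
  and splitk == 2, the partial-sum buffer holds
      [p0 k0 : 6][p0 k1 : 6][p1 k0 : 4][p1 k1 : 4]
  so splitk_buffer_offsets == {0, 6} and problem 1 starts at element
  splitk_buffer_offsets[1] * splitk == 12; each output element i sums the splitk partials
  found at k_idx * hidden_size + i within that problem's slice.
*/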
/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params gemm_params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
{
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
{
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
{
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, std::vector<size_t> const& indices)
{
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size());
for (size_t i = 0; i < indices.size(); ++i)
{
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
public:
/// Constructs the GEMM.
BaseSplitkGrouped() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
{
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const& args)
{
if (args.host_problem_sizes == nullptr)
{
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
size_t total_mn = 0;
for (int i = 0; i < args.problem_count; i++)
{
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
}
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
args.host_problem_sizes, args.problem_count, args.threadblock_count);
}
return workSpaceSize;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each pointer passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of problem-K dimension.
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
{
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
{
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
return 0;
}
int multiprocessor_count;
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
return 0;
}
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
if (override_sm_count)
{
available_sm_count = multiprocessor_count;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0)
{
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0)
{
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return total_tiles
// unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count)
{
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating through
// problem sizes to determine that they have no work to do. This competes for cycles
// with those threadblocks that are assigned tiles to compute.
return std::min(total_tiles, occupancy_based_block_count);
}
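/*
  Illustrative numbers on a hypothetical device: with 108 SMs and maximum_active_blocks() == 2,
  occupancy_based_block_count is 216, so a group totalling 150 tiles launches
  min(150, 216) == 150 threadblocks, while a group with 1000 tiles launches the full wave of 216.
*/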
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
}
else
{
gemm_params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_.update(args, workspace, tile_count);
}
else
{
gemm_params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
if (!gemm_params_.problem_visitor.problem_count)
{
return Status::kSuccess;
}
//
// Launch kernel
//
// Launch splitk grouped gemm
{
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Launch splitkReduction
{
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
dim3 block(256);
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
gemm_params_.splitk_buffer_offsets);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Initializes and runs the kernel.
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
using GemmKernel = GemmKernel_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
static_assert(dependent_false<arch>, "Unrecognised parameterization");
};
template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
static constexpr int ThreadblockK = 8;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and activations will be cast to fp16
// and compute will happen in fp16; the result is then converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
// FP8 A/B = fp8, C/D = fp32
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
// Be careful: TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
using TypeC = __nv_bfloat16;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
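/*
  Illustrative instantiation, using cutlass::uint4b_t as a hypothetical TypeB (any TypeB with a
  LayoutDetailsB specialization behaves analogously):

    using Traits = MixedGemmArchTraits<cutlass::half_t, cutlass::uint4b_t, cutlass::arch::Sm80>;
    // Traits::OperatorClass      -> cutlass::arch::OpClassTensorOp
    // Traits::InstructionShape   -> cutlass::gemm::GemmShape<16, 8, 16>
    // Traits::ElementsPerAccessA -> 128 / 16 == 8 elements per vectorized access
*/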
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename arch>
struct Int8GemmArchTraits
{
using OperatorClass = cutlass::arch::OpClassSimt;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};
// ======================= Turing Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};
// ======================= Ampere Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};
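/*
  Illustrative lookup:

    using Traits = Int8GemmArchTraits<cutlass::arch::Sm80>;
    // Traits::OperatorClass    -> cutlass::arch::OpClassTensorOp
    // Traits::InstructionShape -> cutlass::gemm::GemmShape<16, 8, 32>
*/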
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/permute.h"
#include "splitk_gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
/// Operation performed by GEMM
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
ElementC_, ElementAccumulator>::Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
///
typename Enable = void>
struct DefaultSplitkGemmGrouped;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Permute result D
typename PermuteDLayout>
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA, ElementB, LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
{
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
// Define the default GEMM kernel
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
false, /*GatherB*/
false, /*ScatterD*/
PermuteDLayout>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
};
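/*
  Illustrative wiring sketch with hypothetical tile shapes and epilogue; the parameter order
  follows the primary template above, and the resulting kernel feeds the device-level
  SplitkGemmGrouped wrapper shown earlier in this file set:

    using GemmKernel = typename DefaultSplitkGemmGrouped<
        cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::RowMajor, float,
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel;

    using GemmGrouped = cutlass::gemm::device::SplitkGemmGrouped<GemmKernel>;
*/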
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include <type_traits>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
template <typename>
inline constexpr bool dependent_false_v = false;
}
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct GemmFpAIntB
{
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Mma::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformA;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
/// Parameters structure
struct Arguments
{
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
int group_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments() {}
CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: problem_size(problem_size)
, group_size(group_size)
, ref_A(ref_A)
, ref_B(ref_B)
, ref_scale(ref_scale)
, ref_zero(ref_zero)
, ref_C(ref_C)
, ref_D(ref_D)
, batch_count(serial_split_k_factor)
, output_op(output_op)
, gather_A_indices(gather_A_indices)
, gather_B_indices(gather_B_indices)
, scatter_D_indices(scatter_D_indices)
{
}
};
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
int group_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::Params params_scale;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename EpilogueOutputOp::Params output_op;
int* semaphore;
int gemm_k_size;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, semaphore(0)
, gemm_k_size(0)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
void* workspace = nullptr)
: problem_size(args.problem_size)
, group_size(args.group_size)
, grid_tiled_shape(grid_tiled_shape)
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
, params_A(args.ref_A.layout())
, ref_A(args.ref_A)
, params_B(args.ref_B.layout())
, ref_B(args.ref_B)
, params_scale(args.ref_scale.layout())
, ref_scale(args.ref_scale)
, ref_zero(args.ref_zero)
, params_C(args.ref_C.layout())
, ref_C(args.ref_C)
, params_D(args.ref_D.layout())
, ref_D(args.ref_D)
, output_op(args.output_op)
, semaphore(static_cast<int*>(workspace))
, gemm_k_size(gemm_k_size)
, gather_A_indices(args.gather_A_indices)
, gather_B_indices(args.gather_B_indices)
, scatter_D_indices(args.scatter_D_indices)
{
}
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmFpAIntB() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const& args)
{
static int const kAlignmentA
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!args.ref_scale.good())
{
return Status::kErrorNotSupported;
}
if constexpr (hasZero(Mma::QuantOp))
{
if (!args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
else
{
if (args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
if constexpr (isFinegrained(Mma::QuantOp))
{
if (args.group_size != 64 && args.group_size != 128)
{
return Status::kErrorNotSupported;
}
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
    // Initializes the fine-grained scale+bias iterator. Needed since the fine-grained iterator
    // has a different constructor signature than a regular CUTLASS iterator.
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
}
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
}
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
params.gather_B_indices);
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0)
{
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k())
{
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
{
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else
{
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
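/////////////////////////////////////////////////////////////////////////////////////////////////
// [Editorial sketch, not part of the vendored sources] The file above ends with an architecture
// dispatch in GemmFpAIntB::operator(): each __CUDA_ARCH__ band calls run_kernel<ArchTag>, and
// run_kernel only instantiates the device body when the kernel's KernelArch matches that tag, so
// mismatched bands compile to a stub. The snippet below is a minimal, hedged illustration of that
// pattern; ExampleArchSm75, ExampleArchSm80 and ExampleDispatch are hypothetical names and are not
// used anywhere in this repository.
#include <type_traits>

struct ExampleArchSm75
{
};

struct ExampleArchSm80
{
};

template <typename KernelArch>
struct ExampleDispatch
{
    template <typename CompilationArch>
    __device__ void run_kernel(float* out) const
    {
        if constexpr (std::is_same_v<KernelArch, CompilationArch>)
        {
            *out = 1.0f; // the real mainloop/epilogue would only be emitted for the matching arch
        }
        // else: nothing is instantiated for this architecture band
    }

    __device__ void operator()(float* out) const
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<ExampleArchSm75>(out);
#elif (__CUDA_ARCH__ >= 800)
        run_kernel<ExampleArchSm80>(out);
#endif
#endif
    }
};
/////////////////////////////////////////////////////////////////////////////////////////////////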
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
#include <cutlass/trace.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
int TileK_, int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_sm80
{
static constexpr int kMaxTileM = MaxTileM_;
static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
static constexpr int kTileK = TileK_;
static constexpr int kStages = Stages_;
static constexpr Activation_Type activation_type = activation_type_;
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementOutput = ElementOutput_;
using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM, kTileN,
kTileK, kStages, activation_type>;
using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
using ProblemVisitor
= cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;
struct Arguments
{
Routine_Arguments routine_args;
int problem_count{};
int threadblock_count{};
};
struct Params
{
Routine_Params routine_params;
int threadblock_count{};
typename ProblemVisitor::Params problem_visitor_param;
};
using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
kTileK, kStages, activation_type>;
static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64
static constexpr int kSmemSize = use_m16
? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
: BaseKernelTraits_m16::kSmemSize)
: BaseKernelTraits::kSmemSize;
static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;
    static constexpr bool can_implement(int const available_smem_size)
    {
        return BaseKernelTraits::can_implement(available_smem_size);
    }
static Params to_underlying_arguments(Arguments const& args)
{
return {
{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n,
args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast},
args.threadblock_count,
{args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
args.problem_count, nullptr, 0}};
}
CUTE_DEVICE
void run_device(Params const& params)
{
#define ROUTINE_PATH(kTileM_size) \
{ \
constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM, \
kTileN, kTileK, kStages, activation_type>; \
RoutineTraits routine{}; \
int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \
routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \
}
typename ProblemVisitor::SharedStorage dummy_storage{};
ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);
while (problem_visitor.next_tile())
{
auto problem_size = problem_visitor.problem_size();
auto grid_size = problem_visitor.grid_shape(problem_size);
auto problem_index = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
int const gemm_m = problem_size.m();
const int32_t block_m_idx_temp = cta_idx / grid_size.n();
const int32_t block_n_idx = cta_idx % grid_size.n();
int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
if (residue_m > kMaxTileM / 2)
{
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
kMaxTileM, kTileN, kTileK, kStages, activation_type>;
RoutineTraits routine{};
routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
}
else
{
if constexpr (kMaxTileM >= 128)
{
if (residue_m > 32)
{
ROUTINE_PATH(64);
}
else if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 64)
{
if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 32)
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
problem_visitor.advance(gridDim.x);
}
#undef ROUTINE_PATH
}
};
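// [Editorial sketch, not part of the vendored sources] run_device above falls back to a smaller
// TileM specialization once the remaining rows of a problem no longer fill kMaxTileM
// (ROUTINE_PATH additionally remaps 16 -> 32 when use_m16 is false). The constexpr helper below
// is a hypothetical mirror of that branch structure for kMaxTileM >= 128 with use_m16 == true; it
// exists only to make the selection logic explicit and is not used by the kernel.
constexpr int example_select_tile_m(int residue_m, int k_max_tile_m)
{
    if (residue_m > k_max_tile_m / 2)
    {
        return k_max_tile_m; // full tile
    }
    if (residue_m > 32)
    {
        return 64;
    }
    if (residue_m > 16)
    {
        return 32;
    }
    return 16; // the kernel notes a CUDA-core GEMM could serve this small tail better
}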
template <typename GemmType>
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
{
GemmType gemm;
gemm.run_device(params);
}
/// Computes the maximum number of active blocks per multiprocessor
template <typename GemmType>
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
constexpr int smem_size = GemmType::kSmemSize;
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} // namespace fused_moe
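/////////////////////////////////////////////////////////////////////////////////////////////////
// [Editorial sketch, not part of the vendored sources] A hedged illustration of how a kernel like
// fused_moe::run_global<GemmType> above is typically configured and launched from the host: opt
// in to more than 48 KiB of dynamic shared memory when GemmType::kSmemSize requires it, size a
// persistent-style grid from the occupancy query, and pass the Params by value. The name
// example_launch is hypothetical and error handling is deliberately minimal; the sketch assumes
// the CUDA runtime API (cuda_runtime.h) is available, as it is throughout this repository.
template <typename GemmType>
cudaError_t example_launch(typename GemmType::Params const& params, cudaStream_t stream)
{
    constexpr int smem_size = GemmType::kSmemSize;
    if (smem_size > (48 << 10))
    {
        // Kernels needing more than 48 KiB of dynamic shared memory must opt in explicitly.
        cudaError_t err = cudaFuncSetAttribute(
            fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        if (err != cudaSuccess)
        {
            return err;
        }
    }
    int device = 0;
    int sm_count = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    // One wave of CTAs; the MoE problem visitor inside the kernel walks all tiles of all experts.
    int blocks_per_sm = fused_moe::fused_gemm_maximum_active_blocks<GemmType>();
    if (blocks_per_sm <= 0)
    {
        return cudaErrorInvalidConfiguration;
    }
    dim3 grid(sm_count * blocks_per_sm);
    dim3 block(GemmType::kThreadCount);
    fused_moe::run_global<GemmType><<<grid, block, smem_size, stream>>>(params);
    return cudaGetLastError();
}
/////////////////////////////////////////////////////////////////////////////////////////////////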
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_, typename Enable = void>
struct Fused_Moe_Kernel_routine_sm80;
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
        typename KT::ElementWeight const* ptr_fc1_gate_
            = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // note: this path only handles gated activations
        typename KT::ElementWeight const* ptr_fc1_
            = params.ptr_fc1 + 2 * problem_index * N1 * K1; // note: this path only handles gated activations
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
: params.ptr_bias + (2 * row_jump + 1) * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_gate_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mBias_gate_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
}
    // Note: m_idx changes when a different tile shape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
= gmem_tensor_init(problem_index, gemm_m, params);
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_gate_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sfc1_gate_weight
= cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
// (1) first step, get the fc1_res and fc1_gate
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::Tensor accum_gate
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
cute::clear(accum_gate);
        // check the shapes
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
        // (1.4) retile the smem and rf for the copy
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
// if (cute::thread0()) {
// cute::print(accum_gate(0, 0, 0));
// printf("\n");
// }
// (2) add bias if it has..
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
accum_gate(i) += tOgBiasg(i);
}
}
        // (3) apply the gated activation (SwiGLU/GeGLU)
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter);
}
// (4) push all the result to smem
        // (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile rf and smem for copy back..
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
        // (4.3) copy rf result to smem (TODO: maybe use a for loop for better performance)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
// remember, for all the threads in the same col, they have the same idx for bias..
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
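// [Editorial sketch, not part of the vendored sources] Step (3) of the gated routine above keeps
// two accumulators per output element (one produced through the gate weights, one through the
// regular fc1 weights) and combines them as fn(gate) * value: SwiGLU when fn is SiLU, GeGLU when
// fn is GELU, selected via KT::ActivationFn. The scalar helpers below are hypothetical
// illustrations of that combination and are not used by the kernel.
CUTE_DEVICE float example_silu(float x)
{
    return x / (1.0f + expf(-x)); // SiLU(x) = x * sigmoid(x)
}

CUTE_DEVICE float example_swiglu(float gate_accum, float value_accum)
{
    // Mirrors: accum(i) = fn(accum_gate(i)) * accum(i) in run_routine above.
    return example_silu(gate_accum) * value_accum;
}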
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
}
    // Note: m_idx changes when a different tile shape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
        // (1) first step, get the fc1 result
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
        cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
        // check the shapes
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
        // (1.4) retile the smem and rf for the copy
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
});
// if (cute::thread0()) {
        //     cute::print(accum(0, 0, 0));
// printf("\n");
// }
// (2) add bias if it has..
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
}
}
        // (3) apply the activation function
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum(temp_iter));
}
// (4) push all the result to smem
        // (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile rf and smem for copy back..
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
        // (4.3) copy rf result to smem (TODO: maybe use a for loop for better performance)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
} // namespace fused_moe
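/////////////////////////////////////////////////////////////////////////////////////////////////
// [Editorial sketch, not part of the vendored sources] gmem_tensor_init in both routines above
// derives each expert's base pointers from total_tokens_including_expert, an inclusive prefix sum
// of the tokens routed to each expert: expert i's input/output rows start at
// total_tokens_including_expert[i - 1] (0 for i == 0), and its fc1 weights start at
// i * gemm_n * gemm_k elements; for gated activations each expert owns a 2 * gemm_n * gemm_k
// block holding the value and gate weights back to back. The helper below is a hypothetical
// host-side restatement of that arithmetic and is not used by the kernels.
#include <cstdint>

struct ExampleExpertSlice
{
    int64_t input_row_begin; // first token row owned by this expert
    int64_t input_row_end;   // one past the last token row owned by this expert
    int64_t weight_offset;   // element offset of this expert's fc1 weight block
};

inline ExampleExpertSlice example_expert_slice(
    int64_t const* total_tokens_including_expert, int expert, int gemm_n, int gemm_k, bool gated)
{
    int64_t const row_begin = (expert == 0) ? 0 : total_tokens_including_expert[expert - 1];
    int64_t const row_end = total_tokens_including_expert[expert];
    int64_t const weights_per_expert = static_cast<int64_t>(gemm_n) * gemm_k * (gated ? 2 : 1);
    return {row_begin, row_end, expert * weights_per_expert};
}
/////////////////////////////////////////////////////////////////////////////////////////////////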
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Arguments
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Params
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
enum class Activation_Type
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
InvalidType
};
constexpr bool isGateActivation(Activation_Type const& activation_type)
{
return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
}
template <typename CutlassExtensionEpilogueTag>
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
{
return Activation_Type::InvalidType;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
{
return Activation_Type::Identity;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
{
return Activation_Type::Relu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
{
return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
{
return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
}
/* Fusing all three kernels has many limitations. This is the simpler version: only the first two kernels are fused. */
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type>
struct Fused_Moe_Kernel_traits_sm80
{
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementAccum = float;
using ElementOutput = ElementOutput_;
using index_t = uint32_t;
static_assert(TileM_ % 16 == 0);
static_assert(TileN_ % 32 == 0);
static_assert(TileK_ % 32 == 0);
static constexpr int Stages = Stages_;
static constexpr int kTileM = TileM_;
static constexpr int kTileN = TileN_;
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
// tile shape
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
static constexpr int kWarpsCount = 4;
static constexpr int kThreadCount = kWarpsCount * 32;
// MMA atom arch and layout
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
// using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
using ThreadLayoutMNK
= std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
cute::Tile<cute::_32, cute::_32, cute::_16>>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK,
ValLayoutMNK>; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4
static constexpr int kAlignment = 8;
static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;
// A memory copy operand
using DefaultOperandA
= DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
// B memory copy operand
using DefaultOperandB
= DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
// Output memory copy operand
using SmemLayoutAtomO = SmemLayoutAtomA;
using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
using GmemLayoutAtomO
= cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
cute::make_shape(
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
cute::make_shape(
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
using SmemLayoutO = decltype(cute::tile_to_shape(
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
// We need at least two pipeline stages.
static_assert(Stages >= 2);
struct SharedStorageNormal : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct SharedStorageGate : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;
using ActivationFn = std::conditional_t<activation_type == Activation_Type::Gelu
|| activation_type == Activation_Type::Geglu,
cutlass::epilogue::thread::GELU<float>,
std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
static constexpr bool can_implement(int const available_smem_size)
{
return available_smem_size > kSmemSize;
}
};
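// Example instantiation (a sketch with hypothetical tile sizes, not a tuned configuration):
//   using Traits = Fused_Moe_Kernel_traits_sm80<cutlass::half_t, cutlass::half_t, cutlass::half_t,
//       /*TileM=*/32, /*TileN=*/128, /*TileK=*/32, /*Stages=*/3, Activation_Type::Swiglu>;
// Because Swiglu is a gated activation, Traits::SharedStorage resolves to SharedStorageGate (input,
// fc1 weight, fc1 gate weight, and output tiles), and Traits::can_implement(smem_bytes) should be
// checked against the device's shared-memory capacity before launch.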
} // namespace fused_moe
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Scheduler for grouped GEMM
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
{
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
//
// Methods
//
CUTLASS_DEVICE
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, shared_storage_, block_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
////////////////////////////////////////////////////////////////////////////////
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
class Enable = void>
class GemmUniversalGated;
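// Illustrative 3.x-style instantiation (hypothetical collective types, shown only to document the
// template-argument order described above):
//   using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
//       cute::Shape<int, int, int, int>,      // ProblemShape (m, n, k, l)
//       CollectiveMainloopGated,              // collective mainloop (assumed built elsewhere)
//       CollectiveEpilogue,                   // collective epilogue
//       cutlass::gemm::PersistentScheduler>;  // tile scheduler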
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
namespace tk = tensorrt_llm::common;
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
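// Worked example: for half_t A and B, 128 / sizeof_bits<ElementA>::value = 128 / 16 = 8, so split-K
// partitions are aligned to 8 elements (one 128-bit access).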
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
tk::QuantMode quant_option;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments()
: mode(GemmUniversalMode::kGemm)
, batch_count(1)
{
}
/// constructs an arguments structure
Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(mode_)
, problem_size(problem_size_)
, batch_count(batch_count_)
, ref_A(ref_A_)
, ref_B(ref_B_)
, quant_option(quant_option_)
, ref_alpha_col(ref_alpha_col_)
, ref_alpha_row(ref_alpha_row_)
, ref_C(ref_C_)
, ref_D(ref_D_)
, batch_stride_A(batch_stride_A_)
, batch_stride_B(batch_stride_B_)
, batch_stride_D(0)
, epilogue_visitor(epilogue_visitor_)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void* ptr_A;
void* ptr_B;
tk::QuantMode quant_option;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
ElementC* ptr_C;
ElementC* ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, params_A(0)
, params_B(0)
, params_alpha_col(0)
, params_C(0)
, params_D(0)
, mode(cutlass::gemm::GemmUniversalMode::kGemm)
, batch_count(0)
, gemm_k_size(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_alpha_col(nullptr)
, ptr_alpha_row(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, batch_stride_A(0)
, batch_stride_B(0)
{
}
Params(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
: problem_size(args.problem_size)
, swizzle_log_tile(0)
, params_A(args.ref_A.layout())
, params_B(args.ref_B.layout())
, params_alpha_col(args.ref_alpha_col.layout())
, params_alpha_row(args.ref_alpha_row.layout())
, params_C(args.ref_C.layout())
, params_D(args.ref_D.layout())
, mode(args.mode)
, batch_count(args.batch_count)
, gemm_k_size(args.problem_size.k())
, ptr_A(args.ref_A.data())
, ptr_B(args.ref_B.data())
, quant_option(args.quant_option)
, ptr_alpha_col(args.ref_alpha_col.data())
, ptr_alpha_row(args.ref_alpha_row.data())
, ptr_C(args.ref_C.data())
, ptr_D(args.ref_D.data())
, batch_stride_A(args.batch_stride_A)
, batch_stride_B(args.batch_stride_B)
, epilogue_visitor(args.epilogue_visitor)
{
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
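// Worked example (assumed values): k = 4096, batch_count = 3 split-K slices, half_t operands so
// kAlignK = 8. Then gemm_k_size = round_up(ceil_div(4096, 3), 8) = round_up(1366, 8) = 1368 and
// grid_tiled_shape.k() = ceil_div(4096, 1368) = 3.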
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
struct
{
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
{
isAMisaligned = problem_size.m() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value)
{
isBMisaligned = problem_size.n() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
{
isCMisaligned = problem_size.m() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
{
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched)
{
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray)
{
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm)
{
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
{
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
to be consumed by CUTLASS.
Note that for int4, ThreadBlockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
// for fast accumulation
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
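// Worked example for half_t activations (TypeA = half_t): ThreadblockK = 128 * 8 / 16 = 64.
//   uint8_t weights : ElementsPerCacheLine = 1024 / 8 = 128, ColumnsInterleaved = 128 / 64 = 2
//   uint4b_t weights: ElementsPerCacheLine = 1024 / 4 = 256, ColumnsInterleaved = 256 / 64 = 4
// This is consistent with the file-level note that ThreadblockK must be 64 for int4 weights.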
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_conversion.h>
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
//
// F16: 128-by-128-by-32 (small k-block)
//
/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(cute::size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
CUTE_DEVICE void util_copy(
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
{
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(S); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < cute::size<2>(S); ++k)
{
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
}
}
}
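// Minimal usage sketch (device code; tAgA, tAsA, tCrA, and gmem_tiled_copy_A are hypothetical
// partitioned tensors / tiled copy taken from a surrounding mainloop):
//   util_copy(gmem_tiled_copy_A, tAgA, tAsA);               // per-(m, k) cute::copy of a rank-3 fragment
//   auto tCrA_f16 = util_convert_type<cute::half_t>(tCrA);  // reinterpret + NumericArrayConverter
// util_copy assumes both tensors are rank-3 with matching extents, as the static asserts above enforce,
// and util_convert_type assumes the source fragment is contiguous in registers.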
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Grouped GEMM kernel for Mixture-of-Experts FC layers, with an optional dequantizing mainloop
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular GEMMs and dequantizing GEMMs.
// It will dispatch to the dequantizing GEMM if the Mma type has an iterator for scales in global memory.
template <typename...>
using void_t = void;
template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};
template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};
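// Detection sketch: any Mma exposing a nested IteratorScale type is treated as a dequantizing
// mainloop (hypothetical mainloop names, for illustration only):
//   static_assert(use_dq_gemm<DqMmaMultistageWithScale>::value, "scale iterator present");
//   static_assert(!use_dq_gemm<MmaMultistage>::value, "plain GEMM mainloop");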
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct MoeFCGemm
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
int problem_count;
int threadblock_count;
int group_size;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
bool C_is_broadcast;
int64_t const* total_tokens_including_expert;
int64_t gemm_n;
int64_t gemm_k;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, C_is_broadcast{true}
, total_tokens_including_expert(nullptr)
, gemm_n(0)
, gemm_k(0)
, host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n,
int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count)
, threadblock_count(threadblock_count)
, group_size(group_size)
, output_op(output_op)
, ptr_A(const_cast<ElementA*>(ptr_A))
, ptr_B(const_cast<ElementB*>(ptr_B))
, weight_scales(const_cast<ElementScale*>(weight_scales))
, ptr_C(const_cast<ElementC*>(ptr_C))
, ptr_D(ptr_D)
, C_is_broadcast{C_is_broadcast}
, total_tokens_including_expert(total_tokens_including_expert)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, host_problem_sizes(nullptr)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
assert(weight_scales);
}
}
};
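// Illustrative host-side setup (hypothetical pointers and sizes; a sketch, not a tuned configuration):
//   Arguments args(/*problem_count=*/num_experts, /*threadblock_count=*/occupancy_based_count,
//       /*group_size=*/gemm_k, output_op_params, act_ptr, weight_ptr, /*weight_scales=*/nullptr,
//       bias_ptr, /*C_is_broadcast=*/true, out_ptr, total_tokens_including_expert, gemm_n, gemm_k);
// For uint8_t / uint4b_t weights, weight_scales must be non-null (see can_implement below).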
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int group_size;
bool C_is_broadcast;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: C_is_broadcast(true)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(
args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
, threadblock_count(args.threadblock_count)
, group_size(args.group_size)
, C_is_broadcast(args.C_is_broadcast)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, weight_scales(args.weight_scales)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
{
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
args.gemm_k, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
C_is_broadcast = args.C_is_broadcast;
}
};
/// Shared memory storage structure
union SharedStorage
{
typename ProblemVisitor::SharedStorage problem_visitor;
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
MoeFCGemm() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
if (args.weight_scales == nullptr)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
return Status::kInvalid;
}
}
else if (args.weight_scales != nullptr)
{
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
return Status::kInvalid;
}
else if (args.group_size != args.gemm_k)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
return Status::kInvalid;
}
// Handle the case where the input is too short
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
return Status::kInvalid;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
+ (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
// For LoRA, C is not broadcast and must be addressed with layout_C(gemm_n)
LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
if constexpr (platform::is_same<EpilogueOutputOp,
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
EpilogueOutputOp::kRound>>::value)
{
EpilogueOutputOp output_op(params.output_op, problem_idx);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
else
{
EpilogueOutputOp output_op(params.output_op);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
// Next tile
problem_visitor.advance(gridDim.x);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
if constexpr (isFp8)
{
run_kernel<arch::Sm89>(params, shared_storage);
}
else
{ // Reuse the SM80 kernel for other types, consistent with dispatchToArch
run_kernel<arch::Sm80>(params, shared_storage);
}
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Base scheduler for grouped problems, using MoE
*/
#pragma once
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
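/*
    Usage sketch (illustrative only; member names outside this header, such as
    params.problem_visitor and shared_storage.problem_visitor, are assumptions): a grouped MoE
    GEMM kernel constructs a problem visitor per threadblock and walks its tiles roughly like

        ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
        while (problem_visitor.next_tile())
        {
            cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size();
            int32_t problem_idx = problem_visitor.problem_index();
            int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
            // ... map cta_idx to a threadblock offset, run the mainloop and epilogue ...
            problem_visitor.advance(gridDim.x);
        }
*/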
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor
{
using ThreadblockShape = ThreadblockShape_;
struct ProblemInfo
{
static int32_t const kNoPrefetchEntry = -1;
int32_t problem_idx;
int32_t problem_start;
CUTLASS_DEVICE
ProblemInfo()
: problem_idx(kNoPrefetchEntry)
, problem_start(kNoPrefetchEntry)
{
}
CUTLASS_DEVICE
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
: problem_idx(problem_idx_)
, problem_start(problem_start_)
{
}
};
struct Params
{
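// last_row_for_problem[i] is the running row total after problem i, so problem i covers rows
// [last_row_for_problem[i-1], last_row_for_problem[i]); gemm_n and gemm_k are shared by all problems.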
int64_t const* last_row_for_problem;
int64_t gemm_n;
int64_t gemm_k;
int32_t problem_count;
void const* workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: last_row_for_problem(nullptr)
, gemm_n(0)
, gemm_k(0)
, problem_count(0)
, workspace(nullptr)
, tile_count(0)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
void const* workspace = nullptr, int32_t tile_count = 0)
: last_row_for_problem(last_row_for_problem)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, problem_count(problem_count)
, workspace(workspace)
, tile_count(tile_count)
{
}
};
Params const& params;
int32_t tile_idx;
int32_t problem_tile_start;
int32_t problem_idx;
//
// Methods
//
CUTLASS_DEVICE
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
: params(params_)
, tile_idx(block_idx)
, problem_tile_start(0)
, problem_idx(0)
{
}
/// Get the grid shape
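/// (threadblock counts along M and N via ceiling division; the K dimension is not split)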
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
{
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
}
/// Gets the global tile index
CUTLASS_HOST_DEVICE
int32_t tile_index() const
{
return tile_idx;
}
/// Gets the index of the problem
CUTLASS_HOST_DEVICE
int32_t problem_index() const
{
return problem_idx;
}
CUTLASS_HOST_DEVICE
int32_t threadblock_idx() const
{
return tile_idx - problem_tile_start;
}
CUTLASS_DEVICE
void advance(int32_t grid_size)
{
tile_idx += grid_size;
}
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
{
ProblemSizeHelper::possibly_transpose_problem(problem);
}
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const
{
return problem_size(problem_idx);
}
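/// Computes the extent of problem `idx`: gemm_m is the number of rows assigned to it
/// (the difference of consecutive last_row_for_problem entries); N and K are shared.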
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const
{
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
CUTLASS_HOST_DEVICE
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
{
return ProblemSizeHelper::tile_count(grid);
}
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
{
int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
auto problem = host_problem_sizes_ptr[i];
possibly_transpose_problem(problem);
auto grid = grid_shape(problem);
total_tiles += tile_count(grid);
}
return total_tiles;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor;
/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
{
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
using Params = typename Base::Params;
static int const kThreadCount = ThreadCount;
static bool const kRequiresPrecomputation = false;
static int const kThreadsPerWarp = 32;
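// Device-only scheduling needs no precomputed schedule, so the shared storage and the host
// workspace (see get_workspace_size / host_precompute below) are empty.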
struct SharedStorage
{
};
// Final tile of the problem loaded by this thread. Each thread will hold
// a separate value.
int32_t problem_ending_tile;
SharedStorage& shared_storage;
//
// Methods
//
CUTLASS_DEVICE
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, block_idx)
, problem_ending_tile(0)
, shared_storage(shared_storage_)
{
this->problem_idx = -1 * kThreadsPerWarp;
this->problem_tile_start = 0;
}
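// Finds the problem that owns the current tile_idx. Each lane of the warp caches the ending
// tile of one problem in a group of kThreadsPerWarp problems; a warp-wide prefix sum and
// ballot then locate the owning problem. Returns false once tile_idx is past the last problem.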
CUTLASS_DEVICE
bool next_tile()
{
// Check whether the tile to compute is within the range of the current problem.
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
if (this->tile_idx < problem_tile_end)
{
return true;
}
// Check whether the tile to compute is within the current group of problems fetched by the warp.
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
// register pressure. The starting problem for this group is simply the first problem
// in the group most recently fetched by the warp.
int32_t& group_problem_start = this->problem_idx;
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
// register pressure.
int32_t& group_tile_start = this->problem_tile_start;
// Each thread in the warp holds the ending tile of a separate problem; the warp advances a
// group of kThreadsPerWarp problems at a time until the group's ending tile index exceeds tile_idx.
while (group_tile_end <= this->tile_idx)
{
group_problem_start += kThreadsPerWarp;
if (group_problem_start > this->params.problem_count)
{
return false;
}
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
// is also set here is used later in `next_tile`.
group_tile_start = group_tile_end;
int lane_idx = threadIdx.x % kThreadsPerWarp;
int32_t lane_problem = group_problem_start + lane_idx;
// Compute the number of tiles in the problem assigned to each thread.
problem_ending_tile = 0;
if (lane_problem < this->params.problem_count)
{
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
problem_ending_tile = this->tile_count(grid);
}
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
// each thread's problem.
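// e.g. per-lane tile counts {3, 1, 4, 0, ...} become ending tiles {3, 4, 8, 8, ...}
// (relative offsets; group_tile_start is added below).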
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
{
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i)
{
problem_ending_tile += val;
}
}
// The total tile count for this group is now in the final position of the prefix sum
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}
// The next problem to process is the first one in the group whose ending tile position is
// greater than tile_idx (the ballot counts the problems whose ending tile does not exceed tile_idx).
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
this->problem_idx = group_problem_start + problem_idx_in_group;
// The starting tile for this problem is the ending tile of the previous problem. In cases
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
// loop above.
if (problem_idx_in_group > 0)
{
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
}
return true;
}
static size_t get_workspace_size(
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
{
return 0;
}
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
int32_t block_count, void* host_workspace_ptr)
{
}
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass