Unverified Commit 3ee62235 authored by Yineng Zhang, committed by GitHub

revert the MoE dependence (#3230)

parent 9829e77e
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. The partial
specialization here performs this operand exchange via kernel::detail::MapArguments.
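That is, a column-major D is obtained by computing D^T = B^T * A^T with the row-major epilogue,
which is why the A and B operands (and their layouts) are swapped when LayoutC is column-major.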
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/permute.h"
#include "splitk_gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
/// Operation performed by GEMM
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
ElementC_, ElementAccumulator>::Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
///
typename Enable = void>
struct DefaultSplitkGemmGrouped;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Permute result D
typename PermuteDLayout>
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA, ElementB, LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
{
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
// Define the default GEMM kernel
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
false, /*GatherB*/
false, /*ScatterD*/
PermuteDLayout>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
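//
// Editorial sketch (not part of this commit): one way the trait above might be instantiated.
// Every concrete choice below (element types, tile shapes, epilogue, swizzle, stage count) is an
// illustrative assumption, so the snippet is left disabled.
//
#if 0
using ExampleSplitkGroupedKernel = cutlass::gemm::kernel::DefaultSplitkGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, /*kAlignmentA=*/8,
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, /*kAlignmentB=*/8,
    cutlass::half_t, cutlass::layout::RowMajor, float,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    /*Stages=*/4>::GemmKernel;
#endif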
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel with floating-point activations (A) and quantized
integer weights (B). Supports serial split-K reduction; does not compute batching.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include <type_traits>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
template <typename>
inline constexpr bool dependent_false_v = false;
}
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct GemmFpAIntB
{
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Mma::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
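// Ratio of the B iterator's row extent to the mainloop K tile: 1 for plain row-major B, and >= 1
// for the column-major (possibly tile-interleaved) B layouts used by the weight-only kernels.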
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
/// Parameters structure
struct Arguments
{
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
int group_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments() {}
CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: problem_size(problem_size)
, group_size(group_size)
, ref_A(ref_A)
, ref_B(ref_B)
, ref_scale(ref_scale)
, ref_zero(ref_zero)
, ref_C(ref_C)
, ref_D(ref_D)
, batch_count(serial_split_k_factor)
, output_op(output_op)
, gather_A_indices(gather_A_indices)
, gather_B_indices(gather_B_indices)
, scatter_D_indices(scatter_D_indices)
{
}
};
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
int group_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::Params params_scale;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename EpilogueOutputOp::Params output_op;
int* semaphore;
int gemm_k_size;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, semaphore(0)
, gemm_k_size(0)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
void* workspace = nullptr)
: problem_size(args.problem_size)
, group_size(args.group_size)
, grid_tiled_shape(grid_tiled_shape)
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
, params_A(args.ref_A.layout())
, ref_A(args.ref_A)
, params_B(args.ref_B.layout())
, ref_B(args.ref_B)
, params_scale(args.ref_scale.layout())
, ref_scale(args.ref_scale)
, ref_zero(args.ref_zero)
, params_C(args.ref_C.layout())
, ref_C(args.ref_C)
, params_D(args.ref_D.layout())
, ref_D(args.ref_D)
, output_op(args.output_op)
, semaphore(static_cast<int*>(workspace))
, gemm_k_size(gemm_k_size)
, gather_A_indices(args.gather_A_indices)
, gather_B_indices(args.gather_B_indices)
, scatter_D_indices(args.scatter_D_indices)
{
}
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmFpAIntB() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const& args)
{
static int const kAlignmentA
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!args.ref_scale.good())
{
return Status::kErrorNotSupported;
}
if constexpr (hasZero(Mma::QuantOp))
{
if (!args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
else
{
if (args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
if constexpr (isFinegrained(Mma::QuantOp))
{
if (args.group_size != 64 && args.group_size != 128)
{
return Status::kErrorNotSupported;
}
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
// Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
// has a different constructor signature than a regular cutlass iterator
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
}
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
}
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
using LayoutB = typename Mma::IteratorB::Layout;
static_assert((platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1)
|| (platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1),
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
params.gather_B_indices);
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0)
{
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k())
{
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
{
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else
{
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
static_assert(
false, "Invalid architecture being compiled. Only Turing (SM75) and newer are supported by weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
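// Editorial sketch (not part of this commit): the serial split-K hand-off above, in isolation.
// Partition k of a kSplits-way split waits on lock value k and releases k + 1; the final partition
// releases 0 so the semaphore is reset for the next grid.
#if 0
constexpr int splitk_release_value(int k_partition, int k_splits)
{
    return (k_partition + 1 == k_splits) ? 0 : k_partition + 1;
}
static_assert(splitk_release_value(0, 3) == 1, "first partition unblocks partition 1");
static_assert(splitk_release_value(2, 3) == 0, "last partition resets the lock");
#endif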
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
#include <cutlass/trace.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
int TileK_, int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_sm80
{
static constexpr int kMaxTileM = MaxTileM_;
static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
static constexpr int kTileK = TileK_;
static constexpr int kStages = Stages_;
static constexpr Activation_Type activation_type = activation_type_;
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementOutput = ElementOutput_;
using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM, kTileN,
kTileK, kStages, activation_type>;
using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
using ProblemVisitor
= cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;
struct Arguments
{
Routine_Arguments routine_args;
int problem_count{};
int threadblock_count{};
};
struct Params
{
Routine_Params routine_params;
int threadblock_count{};
typename ProblemVisitor::Params problem_visitor_param;
};
using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
kTileK, kStages, activation_type>;
static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64
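// The runtime dispatch in run_device() may fall back to 16-row tiles for ragged M; reserve dynamic
// shared memory for whichever trait configuration needs more.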
static constexpr int kSmemSize = use_m16
? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
: BaseKernelTraits_m16::kSmemSize)
: BaseKernelTraits::kSmemSize;
static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;
static constexpr bool can_implement(int const available_smem_size)
{
return BaseKernelTraits::can_implement(available_smem_size);
}
static Params to_underlying_arguments(Arguments const& args)
{
return {
{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n,
args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast},
args.threadblock_count,
{args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
args.problem_count, nullptr, 0}};
}
CUTE_DEVICE
void run_device(Params const& params)
{
#define ROUTINE_PATH(kTileM_size) \
{ \
constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM, \
kTileN, kTileK, kStages, activation_type>; \
RoutineTraits routine{}; \
int const block_m_idx = (block_m_idx_temp) * kMaxTileM / kTileM; \
routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \
}
typename ProblemVisitor::SharedStorage dummy_storage{};
ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);
while (problem_visitor.next_tile())
{
auto problem_size = problem_visitor.problem_size();
auto grid_size = problem_visitor.grid_shape(problem_size);
auto problem_index = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
int const gemm_m = problem_size.m();
const int32_t block_m_idx_temp = cta_idx / grid_size.n();
const int32_t block_n_idx = cta_idx % grid_size.n();
int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
if (residue_m > kMaxTileM / 2)
{
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
kMaxTileM, kTileN, kTileK, kStages, activation_type>;
RoutineTraits routine{};
routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
}
else
{
if constexpr (kMaxTileM >= 128)
{
if (residue_m > 32)
{
ROUTINE_PATH(64);
}
else if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 64)
{
if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 32)
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
problem_visitor.advance(gridDim.x);
}
#undef ROUTINE_PATH
}
};
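// Editorial sketch (not part of this commit): the residue-based tile-M dispatch performed by
// run_device() for the kMaxTileM == 128 case, ignoring the use_m16 fallback that promotes the
// 16-row path to 32 rows when TileK_ < 64.
#if 0
constexpr int select_tile_m_128(int residue_m)
{
    if (residue_m > 64) return 128; // more than half a full tile remains: keep the full tile
    if (residue_m > 32) return 64;
    if (residue_m > 16) return 32;
    return 16;
}
static_assert(select_tile_m_128(100) == 128, "");
static_assert(select_tile_m_128(20) == 32, "");
#endif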
template <typename GemmType>
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
{
GemmType gemm;
gemm.run_device(params);
}
/// Computes the maximum number of active blocks per multiprocessor
template <typename GemmType>
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
constexpr int smem_size = GemmType::kSmemSize;
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} // namespace fused_moe
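// A hedged host-side sketch (editorial, not from this commit): sizing Arguments::threadblock_count
// for the persistent problem-visitor loop above as one full wave of resident CTAs, i.e. occupancy
// times the SM count.
#if 0
template <typename GemmType>
int compute_threadblock_count(int device_id = 0)
{
    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, device_id);
    int const blocks_per_sm = fused_moe::fused_gemm_maximum_active_blocks<GemmType>();
    return blocks_per_sm > 0 ? blocks_per_sm * props.multiProcessorCount : 0;
}
#endif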
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_, typename Enable = void>
struct Fused_Moe_Kernel_routine_sm80;
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
typename KT::ElementWeight const* ptr_fc1_gate_
= params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation..
typename KT::ElementWeight const* ptr_fc1_
= params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation..
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
: params.ptr_bias + (2 * row_jump + 1) * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_gate_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mBias_gate_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementOutput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
}
// Be careful: m_idx changes when a different tile shape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
= gmem_tensor_init(problem_index, gemm_m, params);
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_gate_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sfc1_gate_weight
= cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
// (1) first step, get the fc1_res and fc1_gate
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // k-tile iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
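// At this point Stages - 1 k-tiles are in flight, leaving one smem stage free for the mainloop to fill.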
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::Tensor accum_gate
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
cute::clear(accum_gate);
// checkout the shape
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
// (1.4) retiling the smem and rf for copy..
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
// if (cute::thread0()) {
// cute::print(accum_gate(0, 0, 0));
// printf("\n");
// }
// (2) add bias if present
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
accum_gate(i) += tOgBiasg(i);
}
}
// (3) apply the gated activation (e.g. SwiGLU)
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter);
}
// (4) push all the result to smem
// (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile rf and smem for copy back..
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
// remember, for all the threads in the same col, they have the same idx for bias..
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
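// Editorial scalar reference (not part of this commit) for the gated epilogue in step (3) above:
// the routine computes out = act(gate) * fc1 elementwise, where act is e.g. SiLU for SwiGLU.
#if 0
#include <cmath>
inline float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }
inline float gated_activation_ref(float fc1, float gate) { return silu_ref(gate) * fc1; }
#endif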
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementOutput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
}
    // be careful: m_idx changes when a different tile shape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
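        // residue_m: number of valid rows remaining for this expert starting at the current block-M
        // tile; it predicates the partial gmem loads and the final gmem stores below.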
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
        // (1) first step: compute the fc1 result (this non-gated specialization has no gate branch)
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
        // Allocate predicate tensors for the input and fc weight (only the input predicate tensor is actually needed)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // the iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
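        // At this point up to KT::Stages - 1 k-tiles are in flight via cp.async; the mainloop below
        // keeps the pipeline full by issuing the next gmem -> smem copy while computing on the k-tile
        // that has already landed in smem.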
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
        // check that the shapes match
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
        // (1.4) retile the smem and rf tensors for the copy
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
});
        // if (cute::thread0()) {
        //     cute::print(accum(0, 0, 0));
        //     printf("\n");
        // }
        // (2) add the bias if present
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
}
}
        // (3) apply the elementwise activation (no gating multiply in this specialization)
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum(temp_iter));
}
        // (4) push the result to smem
        // (4.1) convert the result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile rf and smem for copy back..
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
        // (4.3) copy the rf result to smem (TODO: maybe use a for-loop for better performance)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow);
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
} // namespace fused_moe
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Arguments
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Params
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
enum class Activation_Type
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
InvalidType
};
constexpr bool isGateActivation(Activation_Type const& activation_type)
{
return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
}
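// Illustrative note (not from the original code): isGateActivation(Activation_Type::Swiglu) and
// isGateActivation(Activation_Type::Geglu) are true, so those types select the gated kernel path
// (SharedStorageGate below carries the extra gate-weight buffer); Silu/Gelu/Relu/Identity use the
// non-gated path.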
template <typename CutlassExtensionEpilogueTag>
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
{
return Activation_Type::InvalidType;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
{
return Activation_Type::Identity;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
{
return Activation_Type::Relu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
{
return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
{
return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
}
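// Summary of the tag -> activation mapping implemented by the specializations above:
//   EpilogueOpDefault       -> Identity
//   EpilogueOpDefaultReLU   -> Relu
//   EpilogueOpDefaultSilu   -> Silu  (Swiglu when is_gate)
//   EpilogueOpDefaultFtGelu -> Gelu  (Geglu when is_gate)
// Any other epilogue tag resolves to InvalidType via the primary template.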
/* Fusing all three kernels has many limitations, so this is the simpler version: only the first two kernels are fused. */
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type>
struct Fused_Moe_Kernel_traits_sm80
{
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementAccum = float;
using ElementOutput = ElementOutput_;
using index_t = uint32_t;
static_assert(TileM_ % 16 == 0);
static_assert(TileN_ % 32 == 0);
static_assert(TileK_ % 32 == 0);
static constexpr int Stages = Stages_;
static constexpr int kTileM = TileM_;
static constexpr int kTileN = TileN_;
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
// tile shape
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
static constexpr int kWarpsCount = 4;
static constexpr int kThreadCount = kWarpsCount * 32;
// MMA atom arch and layout
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
// using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
using ThreadLayoutMNK
= std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
cute::Tile<cute::_32, cute::_32, cute::_16>>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK,
        ValLayoutMNK>; // 32x32x16 or 16x64x16 MMA for LDSM when kWarpsCount == 4
static constexpr int kAlignment = 8;
static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;
// A memory copy operand
using DefaultOperandA
= DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
// B memory copy operand
using DefaultOperandB
= DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
// Output memory copy operand
using SmemLayoutAtomO = SmemLayoutAtomA;
using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
using GmemLayoutAtomO
= cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
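    // Worked example (assuming a 16-bit ElementOutput and kTileM > 16): kGmemElementPerLoad = 16 B / 2 B = 8,
    // kGmemTrheadsPerRow = 32 / 8 = 4, so the 128 threads are laid out as 32 rows x 4 threads and each thread
    // stores 8 contiguous elements (one 128-bit access).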
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
cute::make_shape(
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
cute::make_shape(
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
using SmemLayoutO = decltype(cute::tile_to_shape(
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
// we need at least 2 stages..
static_assert(Stages >= 2);
struct SharedStorageNormal : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct SharedStorageGate : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;
using ActivationFn = std::conditional_t<activation_type == Activation_Type::Gelu
|| activation_type == Activation_Type::Geglu,
cutlass::epilogue::thread::GELU<float>,
std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;
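    // ActivationFn resolution: Gelu/Geglu -> cutlass GELU, Relu -> ReLU, Silu/Swiglu -> SiLU,
    // anything else (including Identity) -> Identity, all computed in float.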
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
    static constexpr bool can_implement(int const available_smem_size)
    {
        return available_smem_size > kSmemSize;
}
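    // Illustrative host-side usage (hypothetical launcher code, not from this file): instantiate the traits,
    // then gate the launch on the dynamic shared-memory budget, e.g.
    //   using KT = Fused_Moe_Kernel_traits_sm80<cutlass::half_t, cutlass::half_t, cutlass::half_t,
    //                                           128, 128, 32, 3, Activation_Type::Silu>;
    //   if (KT::can_implement(max_dynamic_smem)) { /* launch with KT::kSmemSize bytes of dynamic smem */ }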
// #endif
};
} // namespace fused_moe
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Scheduler for grouped GEMM
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
{
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
//
// Methods
//
CUTLASS_DEVICE
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, shared_storage_, block_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
////////////////////////////////////////////////////////////////////////////////
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
class Enable = void>
class GemmUniversalGated;
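// Illustrative: with the 3.x API the first argument is a problem-shape tuple, e.g.
// GemmUniversalGated<cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>,
// whereas a 2.x-style instantiation passes a threadblock-scoped Mma first; the partial
// specializations pulled in by the includes below select the matching implementation.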
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
namespace tk = tensorrt_llm::common;
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
tk::QuantMode quant_option;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments()
: mode(GemmUniversalMode::kGemm)
, batch_count(1)
{
}
/// constructs an arguments structure
Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(mode_)
, problem_size(problem_size_)
, batch_count(batch_count_)
, ref_A(ref_A_)
, ref_B(ref_B_)
, quant_option(quant_option_)
, ref_alpha_col(ref_alpha_col_)
, ref_alpha_row(ref_alpha_row_)
, ref_C(ref_C_)
, ref_D(ref_D_)
, batch_stride_A(batch_stride_A_)
, batch_stride_B(batch_stride_B_)
, batch_stride_D(0)
, epilogue_visitor(epilogue_visitor_)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void* ptr_A;
void* ptr_B;
tk::QuantMode quant_option;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
ElementC* ptr_C;
ElementC* ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, params_A(0)
, params_B(0)
, params_alpha_col(0)
, params_C(0)
, params_D(0)
, batch_count(0)
, gemm_k_size(0)
, mode(cutlass::gemm::GemmUniversalMode::kGemm)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_alpha_col(nullptr)
, ptr_alpha_row(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, batch_stride_A(0)
, batch_stride_B(0)
{
}
Params(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
: problem_size(args.problem_size)
, swizzle_log_tile(0)
, params_A(args.ref_A.layout())
, params_B(args.ref_B.layout())
, params_alpha_col(args.ref_alpha_col.layout())
, params_alpha_row(args.ref_alpha_col.layout())
, params_C(args.ref_C.layout())
, params_D(args.ref_D.layout())
, mode(args.mode)
, batch_count(args.batch_count)
, gemm_k_size(args.problem_size.k())
, ptr_A(args.ref_A.data())
, ptr_B(args.ref_B.data())
, quant_option(args.quant_option)
, ptr_alpha_col(args.ref_alpha_col.data())
, ptr_alpha_row(args.ref_alpha_row.data())
, ptr_C(args.ref_C.data())
, ptr_D(args.ref_D.data())
, batch_stride_A(args.batch_stride_A)
, batch_stride_B(args.batch_stride_B)
, epilogue_visitor(args.epilogue_visitor)
{
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
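        // Illustrative split-K sizing (assuming 16-bit A and B): kAlignK = max(128/16, 128/16, 1) = 8;
        // with problem_size.k() = 4096 and batch_count = 2 split-K slices, gemm_k_size =
        // round_up(ceil_div(4096, 2), 8) = 2048 and grid_tiled_shape.k() = ceil_div(4096, 2048) = 2.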
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
struct
{
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
{
isAMisaligned = problem_size.m() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value)
{
isBMisaligned = problem_size.n() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
{
isCMisaligned = problem_size.m() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
{
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched)
{
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray)
{
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm)
{
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
{
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
to be consumed by CUTLASS.
Note that for int4, ThreadBlockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
// for fast accumulation
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB < TypeA,
uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
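// Worked example (assuming 16-bit TypeA): ThreadblockK = 128 * 8 / 16 = 64. For uint8_t,
// ElementsPerCacheLine = 1024 / 8 = 128, so ColumnsInterleaved = 128 / 64 = 2; for uint4b_t,
// ElementsPerCacheLine = 1024 / 4 = 256, so ColumnsInterleaved = 256 / 64 = 4, matching the note
// at the top of the file that ThreadBlockK must be 64 for int4.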
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_conversion.h>
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
//
// F16: 128-by-128-by-32 (small k-block)
//
/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(cute::size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
}
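// Typical use in the fused MoE kernel above: convert the float accumulator fragment to the output
// element type before staging it through shared memory, e.g.
//   cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);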
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
CUTE_DEVICE void util_copy(
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
{
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(S); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < cute::size<2>(S); ++k)
{
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
}
}
}
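// Illustrative usage (hypothetical names), assuming S and D are rank-3 thread partitions
// of the form (CPY, CPY_M, CPY_K) produced by the same TiledCopy:
//   auto smem_tiled_copy_A = cute::make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);
//   util_copy(smem_tiled_copy_A, tCsA, tCrA_view);
// Issuing one cute::copy per (m, k) sub-tile lets the compiler fully unroll and interleave the loads.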
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular GEMMs and dequantizing GEMMs.
// It will dispatch to the dequantizing GEMM if the Mma type has an iterator for scales in global memory.
template <typename...>
using void_t = void;
template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};
template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};
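// Detection-idiom sketch: use_dq_gemm<Mma>::value is true only when Mma exposes a nested
// `IteratorScale` type, i.e. when the mainloop is a dequantizing (mixed-input) mainloop.
// Illustrative only (hypothetical Mma types):
//   static_assert(use_dq_gemm<SomeDqMmaMultistage>::value, "scale iterator detected");
//   static_assert(!use_dq_gemm<SomePlainMmaMultistage>::value, "plain mainloop has no scales");
// The kernel below branches on this trait with `if constexpr` to decide whether to pass a scale iterator.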
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct MoeFCGemm
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
int problem_count;
int threadblock_count;
int group_size;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
bool C_is_broadcast;
int64_t const* total_tokens_including_expert;
int64_t gemm_n;
int64_t gemm_k;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, C_is_broadcast{true}
, total_tokens_including_expert(nullptr)
, gemm_n(0)
, gemm_k(0)
, host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n,
int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count)
, threadblock_count(threadblock_count)
, group_size(group_size)
, output_op(output_op)
, ptr_A(const_cast<ElementA*>(ptr_A))
, ptr_B(const_cast<ElementB*>(ptr_B))
, weight_scales(const_cast<ElementScale*>(weight_scales))
, ptr_C(const_cast<ElementC*>(ptr_C))
, ptr_D(ptr_D)
, C_is_broadcast{C_is_broadcast}
, total_tokens_including_expert(total_tokens_including_expert)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, host_problem_sizes(nullptr)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
assert(weight_scales);
}
}
};
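// A minimal host-side sketch of filling Arguments for a grouped MoE GEMM (hypothetical
// values; real callers derive these from the expert routing results):
//   MoeFCGemm::Arguments args(
//       /*problem_count=*/num_experts, /*threadblock_count=*/occupancy * sm_count,
//       /*group_size=*/gemm_k, output_op_params, ptr_A, ptr_B, /*weight_scales=*/nullptr,
//       ptr_C, /*C_is_broadcast=*/true, ptr_D, total_tokens_including_expert, gemm_n, gemm_k);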
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int group_size;
bool C_is_broadcast;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: C_is_broadcast(true)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(
args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
, threadblock_count(args.threadblock_count)
, group_size(args.group_size)
, C_is_broadcast(args.C_is_broadcast)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, weight_scales(args.weight_scales)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
{
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
args.gemm_k, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
group_size = args.group_size;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
C_is_broadcast = args.C_is_broadcast;
}
};
/// Shared memory storage structure
union SharedStorage
{
typename ProblemVisitor::SharedStorage problem_visitor;
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
MoeFCGemm() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
if (args.weight_scales == nullptr)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
return Status::kInvalid;
}
}
else if (args.weight_scales != nullptr)
{
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
return Status::kInvalid;
}
else if (args.group_size != args.gemm_k)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
return Status::kInvalid;
}
// Handle the case where the input is too short
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
return Status::kInvalid;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
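// Worked example (hypothetical sizes): for 4-bit weights with gemm_k = 4096 and
// gemm_n = 14336, bytes_per_expert_matrix = (4096 * 14336 / 8) * 4 = 29,360,128 bytes,
// so expert i's weight matrix starts at ptr_B + i * 29,360,128 bytes.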
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
+ (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
// For LoRA (C is not broadcast), layout_C needs to be constructed as LayoutC(gemm_n)
LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
LayoutC layout_D(gemm_n);
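// With a row-major C, a leading dimension of 0 maps every output row of the tile onto the
// same source row, so a single (1 x gemm_n) bias row is re-read for all rows when C is broadcast.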
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
if constexpr (platform::is_same<EpilogueOutputOp,
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
EpilogueOutputOp::kRound>>::value)
{
EpilogueOutputOp output_op(params.output_op, problem_idx);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
else
{
EpilogueOutputOp output_op(params.output_op);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
// Next tile
problem_visitor.advance(gridDim.x);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
if constexpr (isFp8)
{
run_kernel<arch::Sm89>(params, shared_storage);
}
else
{ // reuse sm80 kernel for other types, align with dispatchToArch
run_kernel<arch::Sm80>(params, shared_storage);
}
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Base scheduler for grouped problems, using MoE
*/
#pragma once
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor
{
using ThreadblockShape = ThreadblockShape_;
struct ProblemInfo
{
static int32_t const kNoPrefetchEntry = -1;
int32_t problem_idx;
int32_t problem_start;
CUTLASS_DEVICE
ProblemInfo()
: problem_idx(kNoPrefetchEntry)
, problem_start(kNoPrefetchEntry)
{
}
CUTLASS_DEVICE
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
: problem_idx(problem_idx_)
, problem_start(problem_start_)
{
}
};
struct Params
{
int64_t const* last_row_for_problem;
int64_t gemm_n;
int64_t gemm_k;
int32_t problem_count;
void const* workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: last_row_for_problem(nullptr)
, gemm_n(0)
, gemm_k(0)
, problem_count(0)
, workspace(nullptr)
, tile_count(0)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
void const* workspace = nullptr, int32_t tile_count = 0)
: last_row_for_problem(last_row_for_problem)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, problem_count(problem_count)
, workspace(workspace)
, tile_count(tile_count)
{
}
};
Params const& params;
int32_t tile_idx;
int32_t problem_tile_start;
int32_t problem_idx;
//
// Methods
//
CUTLASS_DEVICE
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
: params(params_)
, tile_idx(block_idx)
, problem_tile_start(0)
, problem_idx(0)
{
}
/// Get the grid shape
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
{
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
}
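// Worked example: for a problem of size (m, n) = (1000, 14336) and a 128x128 threadblock
// tile, grid_shape returns ((1000 - 1 + 128) / 128, (14336 - 1 + 128) / 128, 1) = (8, 112, 1),
// i.e. the usual ceiling-division tile counts.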
/// Gets the global tile index
CUTLASS_HOST_DEVICE
int32_t tile_index() const
{
return tile_idx;
}
/// Gets the index of the problem
CUTLASS_HOST_DEVICE
int32_t problem_index() const
{
return problem_idx;
}
CUTLASS_HOST_DEVICE
int32_t threadblock_idx() const
{
return tile_idx - problem_tile_start;
}
CUTLASS_DEVICE
void advance(int32_t grid_size)
{
tile_idx += grid_size;
}
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
{
ProblemSizeHelper::possibly_transpose_problem(problem);
}
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const
{
return problem_size(problem_idx);
}
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const
{
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
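// Worked example (hypothetical routing): with last_row_for_problem = {128, 320, 400},
// problem 1 spans rows [128, 320), so problem_size(1) = (192, gemm_n, gemm_k).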
CUTLASS_HOST_DEVICE
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
{
return ProblemSizeHelper::tile_count(grid);
}
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
{
int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
auto problem = host_problem_sizes_ptr[i];
possibly_transpose_problem(problem);
auto grid = grid_shape(problem);
total_tiles += tile_count(grid);
}
return total_tiles;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor;
/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
{
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
using Params = typename Base::Params;
static int const kThreadCount = ThreadCount;
static bool const kRequiresPrecomputation = false;
static int const kThreadsPerWarp = 32;
struct SharedStorage
{
};
// Final tile of the problem loaded by this thread. Each thread will hold
// a separate value.
int32_t problem_ending_tile;
SharedStorage& shared_storage;
//
// Methods
//
CUTLASS_DEVICE
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, block_idx)
, problem_ending_tile(0)
, shared_storage(shared_storage_)
{
this->problem_idx = -1 * kThreadsPerWarp;
this->problem_tile_start = 0;
}
CUTLASS_DEVICE
bool next_tile()
{
// Check whether the tile to compute is within the range of the current problem.
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
if (this->tile_idx < problem_tile_end)
{
return true;
}
// Check whether the tile to compute is within the current group of problems fetched by the warp.
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
// register pressure. The starting problem for this group is simply the first problem
// in the group most recently fetched by the warp.
int32_t& group_problem_start = this->problem_idx;
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
// register pressure.
int32_t& group_tile_start = this->problem_tile_start;
// Each thread in the warp holds a separate problem, and the warp advances a group of
// problems at a time until it reaches a group whose ending tile is greater than tile_idx.
while (group_tile_end <= this->tile_idx)
{
group_problem_start += kThreadsPerWarp;
if (group_problem_start > this->params.problem_count)
{
return false;
}
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
// is also set here is used later in `next_tile`.
group_tile_start = group_tile_end;
int lane_idx = threadIdx.x % kThreadsPerWarp;
int32_t lane_problem = group_problem_start + lane_idx;
// Compute the number of tiles in the problem assigned to each thread.
problem_ending_tile = 0;
if (lane_problem < this->params.problem_count)
{
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
problem_ending_tile = this->tile_count(grid);
}
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
// each thread's problem.
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
{
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i)
{
problem_ending_tile += val;
}
}
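// Example of the inclusive prefix sum (4 lanes for brevity): per-lane tile counts
// {3, 1, 4, 2} become ending offsets {3, 4, 8, 10}; adding group_tile_start then turns
// them into global ending tile indices for this group of problems.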
// The total tile count for this group is now in the final position of the prefix sum
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}
// The next problem to process is the first one whose ending tile position is
// greater than the tile index.
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
this->problem_idx = group_problem_start + problem_idx_in_group;
// The starting tile for this problem is the ending tile of the previous problem. In cases
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
// loop above.
if (problem_idx_in_group > 0)
{
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
}
return true;
}
static size_t get_workspace_size(
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
{
return 0;
}
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
int32_t block_count, void* host_workspace_ptr)
{
}
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<
cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>
&& CollectiveMainloop_::isGated>>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock
= CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
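// Assumed semantics of OrderedSequenceBarrier<1, 2>: one stage shared by two ordered groups,
// so the mainloop producer warp signals once and the epilogue producer warp only starts
// issuing its loads after the first mainloop loads have been committed.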
// Kernel level shared memory storage
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16>
{
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void* workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static Params to_underlying_arguments(Arguments const& args, void* workspace)
{
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0)
{
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = nullptr;
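// Workspace layout sketch: [ tile-scheduler workspace | pad to MinWorkspaceAlignment |
// epilogue workspace | pad ]. get_workspace_size() and initialize_workspace() below must
// walk the same offsets in the same order for the pointers computed here to stay valid.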
// Precompute the number of epilogue subtiles and pass it to the tile scheduler; it is used
// by the separate-reduction scheme in the stream-K case. NumEpilogueSubTiles defaults to 1,
// which means subtiles are not used and separate reduction is not enabled.
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{},
ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
return {args.mode, problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
scheduler, workspace};
}
static bool can_implement(Arguments const& args)
{
bool implementable = (args.mode == GemmUniversalMode::kGemm)
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable)
{
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const& args)
{
size_t workspace_size = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
{
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const& params)
{
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
{
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}
static dim3 get_block_shape()
{
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf)
{
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole
{
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
enum class ProducerWarpRole
{
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate)
{
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = []()
{
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the cluster
if constexpr (size(ClusterShape{}) > 1)
{
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
}
else
{
__syncthreads();
return []() {}; // do nothing
}
}();
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.get_current_work();
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer)
{
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop)
{
bool do_load_order_arrive = true;
while (work_tile_info.is_valid())
{
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
{
work_tile_info = fetch_next_work(work_tile_info, scheduler);
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the
// work.
auto work_k_tile_count
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter
= cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster,
shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive)
{
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
{
while (work_tile_info.is_valid())
{
if (!TileScheduler::requires_separate_reduction(params.scheduler))
{
load_order_barrier.wait();
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline,
epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx());
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
// Track whether we may need to issue tail arrives for TMA stores, in case the epilogue load is waiting on them
bool do_store_tail = false;
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto work_k_tile_count
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
// Allocate the accumulators for the (M,N) blk_shape
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
{
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop,
params.mainloop);
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
// Perform reduction across splits, if needed
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx);
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx);
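// Gated-activation fusion: the two accumulators come from the two fused projections of the
// gated GEMM (acc1 feeds the activation, i.e. the gate). For example, assuming Activation is
// a SiLU functor, the loop below computes
//   d[i] = (acc0[i] * scale_d0) * silu(scale_d1 * acc1[i])
// i.e. a SwiGLU-style elementwise gate, before a single epilogue stores the result.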
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++)
{
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
{
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
do_store_tail = true;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
if (do_store_tail)
{
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state);
}
} // Consumer Warp Groups End
#endif
}
private:
// Kernel helper function to get next work unit
CUTLASS_DEVICE
typename TileScheduler::WorkTileInfo fetch_next_work(
typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const
{
// Check whether we should continue on with the current work unit. If this is the case,
// the work unit will have been updated in continue_current_work to reflect the new
// tile to be computed.
if (scheduler.continue_current_work(work_tile_info))
{
return work_tile_info;
}
// Get next work tile
scheduler.advance_to_next_work();
return scheduler.get_current_work();
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cute/util/debug.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>
&& CollectiveMainloop_::isGated>>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
"Ping-pong kernel does not currently support stream-K scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
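// With 1 load warp group and 2 math warp groups, MaxThreadsPerBlock works out to 3 * 128 = 384 threads.
// The producer (load) warp group deallocates down to 40 registers per thread while each math warp group
// allocates up to 232, so the per-block register demand is 128 * (40 + 232 + 232) = 64512 registers,
// just under the 64K-register file of an SM90 SM.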
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
// Kernel level shared memory storage
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16>
{
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static Params to_underlying_arguments(Arguments const& args, void* workspace)
{
CUTLASS_TRACE_HOST("to_underlying_arguments():");
(void) workspace;
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0)
{
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
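// Workspace layout, in order: tile-scheduler workspace, then epilogue workspace, each rounded up to
// MinWorkspaceAlignment. The mainloop does not use any workspace, so it is handed a nullptr below.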
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = nullptr;
return {args.mode, problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
}
static bool can_implement(Arguments const& args)
{
bool implementable = (args.mode == GemmUniversalMode::kGemm)
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable)
{
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const& args)
{
size_t workspace_size = 0;
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
{
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const& params)
{
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
{
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}
static dim3 get_block_shape()
{
return dim3(MaxThreadsPerBlock, 1, 1);
}
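// Typical host-side usage (a sketch; in CUTLASS 3.x this is normally wrapped by
// cutlass::gemm::device::GemmUniversalAdapter rather than called by hand):
//   1. Build Arguments{mode, problem_shape, mainloop, epilogue, hw_info, scheduler}.
//   2. Check can_implement(args), then allocate get_workspace_size(args) bytes of device memory.
//   3. Call initialize_workspace(args, workspace, stream) and to_underlying_arguments(args, workspace).
//   4. Launch with get_grid_shape(params), get_block_shape(), and SharedStorageSize bytes of dynamic smem.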
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf)
{
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole
{
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
enum class ProducerWarpRole
{
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
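// Within the single producer warp group, warp 0 issues the TMA loads for the mainloop (A/B tiles) and
// warp 2 issues the TMA loads for the epilogue source tensor C; warps 1 and 3 have no assigned role and
// simply idle through the persistent loop.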
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate)
{
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [&]()
{
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the cluster
if constexpr (size(ClusterShape{}) > 1)
{
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
}
else
{
__syncthreads();
return []() {}; // do nothing
}
}();
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
if (warp_group_role == WarpGroupRole::Consumer1)
{
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
}
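// Ping-pong schedule: the two math warp groups alternate work tiles (Consumer0 takes tiles 0, 2, 4, ...
// and Consumer1 takes tiles 1, 3, 5, ...). Consumer1 therefore skips ahead by one tile here and offsets
// its pipeline states by one tile's worth of stages so that the two warp groups consume disjoint buffers.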
auto work_tile_info = scheduler.get_current_work();
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer)
{
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop)
{
bool do_load_order_arrive = true;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster,
shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive)
{
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
{
load_order_barrier.wait();
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state
= collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL,
blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue);
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Allocate the accumulators for the (M,N) blk_shape
Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
// Order the two Math WGs' MMAs one after the other; this helps hide the Epilogue
math_wg_order_barrier.wait();
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1,
k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++)
{
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
}
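// Gated fusion: for the same tile, accumulators0 and accumulators1 hold the two GEMM results (loosely,
// the "up" and "gate" projections of a gated MLP). The fused result is
// (scale_d0 * acc0) * Activation(scale_d1 * acc1), e.g. a SiLU activation for SwiGLU-style layers;
// only accumulators0 carries the fused values into the epilogue below.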
// Order the two Math WGs' Epilogues one after the other
math_wg_order_barrier.wait();
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue);
// The TMA store pipeline wait is only visible to the TMA-issuing warp, so for multiple-consumer kernels
// we must wait for all TMA stores to complete before issuing the consumer order barrier arrive,
// ensuring the next math consumer does not overwrite smem still referenced by this consumer's in-flight TMA stores.
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_]
= collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next);
// Update starting load/store pipeline states for the next tile.
// The states have already been advanced by one tile inside the collective calls; advance them once more
// so that the tile processed by the other math warp group is skipped (ping-pong schedule).
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Get next work tile
scheduler.advance_to_next_work(NumMmaWarpGroups);
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed = false>
struct SplitkGemmGrouped
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = Transposed;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementFinalOutput = typename MapArguments::ElementA;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmCoord* problem_sizes;
int problem_count;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
// splitK
int split_k_slices;
int64_t* splitk_buffer_offsets;
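// The pointer members above are device-side arrays of length problem_count: ptr_A[i]/ptr_B[i]/ptr_C[i]/
// ptr_D[i] and lda[i]/ldb[i]/ldc[i]/ldd[i] describe the i-th GEMM in the group. splitk_buffer_offsets
// provides a per-problem offset used to locate that problem's split-K scratch region inside the shared
// workspace.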
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
int64_t* splitk_buffer_offsets)
: problem_sizes(problem_sizes)
, problem_count(problem_count)
, threadblock_count(threadblock_count)
, output_op(output_op)
, ptr_A(ptr_A)
, ptr_B(ptr_B)
, ptr_C(ptr_C)
, ptr_D(ptr_D)
, lda(lda)
, ldb(ldb)
, ldc(ldc)
, ldd(ldd)
, host_problem_sizes(host_problem_sizes)
, split_k_slices(split_k_slices)
, splitk_buffer_offsets(splitk_buffer_offsets)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
ElementC* ptr_C_split;
ElementC* ptr_D_split;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
//
// Methods
//
// splitk
GemmCoord grid_tiled_shape;
int swizzle_log_tile;
int gemm_k_size;
GemmCoord* host_problem_sizes;
int split_k_slices;
int64_t* splitk_buffer_offsets;
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, ptr_C_split(nullptr)
, ptr_D_split(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, swizzle_log_tile(0)
, gemm_k_size(0)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
, threadblock_count(args.threadblock_count)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
, ptr_C_split((ElementC*) workspace)
, ptr_D_split((ElementC*) workspace)
, lda(args.lda)
, ldb(args.ldb)
, ldc(args.ldc)
, ldd(args.ldd)
, host_problem_sizes(args.host_problem_sizes)
, split_k_slices(args.split_k_slices)
, splitk_buffer_offsets(args.splitk_buffer_offsets)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
// Only a uniform K across all problems in the group is supported; K is taken from the first problem.
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
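// Illustrative example (hypothetical sizes): with k() = 4096, Mma::Shape::kK = 64 and split_k_slices = 4,
// full_gemm_k_iterations = 64, gemm_k_iterations = 16 and gemm_k_size = 1024, i.e. each of the 4 K-slices
// covers a contiguous 1024-wide chunk of K (the last slice also absorbs any remainder).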
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor =
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
ptr_C_split = static_cast<ElementC*>(workspace);
ptr_D_split = static_cast<ElementC*>(workspace);
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
}
};
/// Shared memory storage structure
struct SharedStorage
{
union
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
SplitkGemmGrouped() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile())
{
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// Load element pointers. Exchange pointers and strides if working on the transpose
ElementA* ptr_A
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB* ptr_B
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k;
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
{
problem_size_k = problem_size.k();
}
else
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx_sync();
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = params.ptr_C_split;
ElementC* ptr_D = params.ptr_D_split;
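// Split-K: every K-slice writes its partial sums into a private (M x N) slab of the workspace rather
// than into the user's output. The pointer offsets below select the slab for this K-slice
// (threadblock_tile_offset.k()) and this problem (splitk_buffer_offsets[problem_idx], scaled by the
// number of slices in gridDim.z); a separate reduction pass is expected to combine the slabs into ptr_D.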
LayoutC layout_C(params.ldc[problem_idx]);
LayoutC layout_D(params.ldd[problem_idx]);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// assume identity swizzle
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
// We need to distinguish here, since we want Volta support. It would be too much effort to write the
// shared memory iterators that would probably be needed for Volta to function properly. As a result,
// we allow converters both after the LDG (for Volta) and after the LDS (for Turing+).
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Warp level Mma
typename MmaOperator,
/// Math operation perform by warp level operator
typename MathOperator>
struct SetConverters
{
};
// Dequantize after LDG, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
{
using TransformAfterLDG
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element, IteratorB::Fragment::kElements>;
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
};
// Dequantize after LDS, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
{
using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
};
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale_,
/// Layout for the scale operand
typename LayoutScale_,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
///
typename Enable = void>
struct DqMma;
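// DqMma is the threadblock-level builder for mixed-input ("dequantizing") mainloops: A is a
// floating-point operand, B is a quantized (uint8/uint4) operand converted to the compute type on the
// fly using the converters selected above, with per-column or fine-grained scales supplied through the
// scale iterators. The partial specializations that follow select a multistage (sm80+) or pipelined
// (pre-sm80) mainloop and handle the optional column-major tile-interleaved B layout.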
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsMultistage;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
using SmemIteratorScale = IteratorScale;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
// ThreadMap for scale iterator
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale = IteratorScale;
};
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
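// Illustrative example (hypothetical tile sizes): with ColumnsInterleaved = 4, MmaCore::Shape::kK = 64
// and MmaCore::Shape::kN = 128, the global-memory iterator sees B as a (256 x 32) column-major tile, and
// the warp arrangement is stretched 4x along the contiguous dimension (and shrunk 4x along the strided
// dimension) to match the interleaving.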
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsPipelined;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
using SmemScaleType = half_t;
public:
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType, Layout, 0, Alignment>;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
// ThreadMap for scale iterator
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
using SmemScaleType = half_t;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
Layout, 0, IteratorScaleThreadMap, Alignment>;
};
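// Worked example for the per-column scale thread map above (illustrative numbers, not from the
// original source): with MmaShape::kN == 128 and Alignment == 8, PitchLinearStripminedThreadMap
// spreads the single row of 128 scales over 128 / 8 == 16 threads, each issuing one 8-element
// (128-bit for half_t) vectorized access per threadblock tile.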
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};
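// Explanatory note on the converter placement above (added comment; the reasoning is inferred from
// the type definitions): DqAfterLDG is true only when the plain arch::OpMultiplyAdd operator is
// requested. In that case MmaCoreElementB is half_t, so B fragments are dequantized right after the
// global load (TransformAfterLDG) and land in shared memory as fp16. Otherwise B stays in its
// quantized uint8/uint4 form in shared memory and is dequantized after the shared-memory load
// (TransformAfterLDS), which is the choice SetConverters encodes.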
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
OperatorClass, 2, Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};
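// Note on the interleaved-B arithmetic above (illustrative numbers, not from the original source):
// with ColumnsInterleaved == 2 and a threadblock B tile of [kK, kN] == [64, 128], the global-memory
// iterator instead walks a [128, 64] tile, because pairs of interleaved columns are stored
// contiguously in memory; the warp arrangement is rescaled by the same factor so each thread still
// issues full kAccessSizeInBits-wide accesses.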
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
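// Usage sketch for the specialization above (kept commented out on purpose; the tile shapes, arch
// tag and operator-tag spelling below are illustrative assumptions, not values taken from this file):
//
//   using TaggedOp = cutlass::arch::TagOperator<cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA,
//       cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>::TaggedOperator;
//   using ExampleMma = DefaultMma<
//       cutlass::half_t, cutlass::layout::RowMajor, 8,   // A: fp16 activations, 128-bit accesses
//       uint8_t, cutlass::layout::ColumnMajor, 16,       // B: int8 weights, 128-bit accesses
//       float, cutlass::layout::RowMajor,                // fp32 accumulation, row-major output
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
//       cutlass::gemm::GemmShape<128, 128, 64>,          // threadblock tile (assumed)
//       cutlass::gemm::GemmShape<64, 64, 64>,            // warp tile (assumed)
//       cutlass::gemm::GemmShape<16, 8, 8>,              // Turing fp16 tensor-op instruction tile
//       TaggedOp>;                                       // routes to the 2-stage DqMmaPipelined path
//   // ExampleMma::ThreadblockMma is then the resulting threadblock-scoped pipelined mma type.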
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#endif
// fp16 x fp16 specialization on Ampere that uses mma multistage even for 2 stages. This helps avoid register
// spills on large tiles when there is not enough shared memory for 3 or more stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
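// Note on the stage counts above (explanatory comment added here; the rationale is inferred from the
// surrounding code): DefaultMmaCore is instantiated with 3 stages so that it selects the
// cp.async-based multistage shared-memory iterators and cache-op policies, while the final
// MmaMultistage is built with Stages == 2. The result is a double-buffered pipeline that keeps the
// 2-stage shared-memory footprint but uses the multistage machinery, which is what avoids the
// register spills mentioned above.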
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
private:
// Conversions are only needed pre-Ampere. That path triggers the mma pipelined kernel, so we convert before the STS (shared-memory store).
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};
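// Explanatory note (added comment; the rationale is inferred from the code above): pre-Ampere tensor
// cores have no native bf16 MMA, so both operands are mapped to half_t for the MmaCore and converted
// by the pipelined mma's fragment transforms before the shared-memory store. On SM80 and newer,
// arch_has_bf16_mma is true and the operands stay in bfloat16_t end to end.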
// bf16 x bf16 specialization on Ampere that uses mma multistage even for 2 stages. This helps avoid register
// spills on large tiles when there is not enough shared memory for 3 or more stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
false, SharedMemoryClear, GatherA, GatherB>
{
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA, GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB, GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
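////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight, mma multistage
/// (stage>=3)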
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass