Unverified commit 3ee62235 authored by Yineng Zhang, committed by GitHub

revert the MoE dependence (#3230)

parent 9829e77e
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/weight_only_quant_op.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so we can keep the same loop code for Volta and dispatch to the
// correct warp-level MMA. On Volta, all data is stored to shared memory as FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C);
}
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
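// Illustrative only: the dispatch relies on the default template argument. For a warp MMA
// type that defines kExpansionFactor (the dequantizing tensor-op path), the second overload
// is viable and forwards warp_tileB_k_offset; for a type that does not (the Volta path,
// where operands are already FP16 in shared memory), substitution of the default argument
// fails, leaving only the first overload, which ignores the offset.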
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// The type of the scales
typename ElementScale_,
/// Number of stages,
int Stages,
/// The dequantizing op to be performed.
WeightOnlyQuantOp DequantOp,
/// Used for partial specialization,
typename Enable = bool>
class DqMmaBase
{
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
///< Type of the scale to be loaded
using ElementScale = ElementScale_;
static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");
// Fine-grained scales are streamed in via cp.async
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
// We always have scales.
static constexpr int ScaleElementsPerStage = Shape::kN;
// We sometimes have a bias
static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;
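// Illustrative sizing (hypothetical numbers): with Shape::kN == 128 and Stages == 4, a
// fine-grained op with zero-points gives ScalebiasStages == 4, ScaleElementsPerStage == 128
// and BiasElementsPerStage == 128, i.e. one row of scales and one row of zero-points per
// pipeline stage. A per-column op instead gives ScalebiasStages == 1 and
// BiasElementsPerStage == 0, so a single row of scales is shared by all stages.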
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static constexpr int kNumKIterationsPerWarpBLoad
= Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
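// Worked example (hypothetical numbers): with WarpGemm::kK == 64 and
// Operator::Policy::MmaShape::kK == 16, kWarpGemmIterations == 4. If B is loaded 32 rows at
// a time (Operator::IteratorB::InstructionShape::kRow == 32) while each instruction consumes
// 16 (Operator::InstructionShape::kK == 16), then kNumKIterationsPerWarpBLoad == 2 and
// kWarpGemmIterationsForB == 2, i.e. one B load feeds two warp-level MMA steps.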
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage
{
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA
= MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB
= MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of the shared memory buffer for the scales for the B matrix.
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
/// Shape of the shared memory buffer for the biases of the B matrix.
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
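// Illustrative footprint (hypothetical numbers): for Shape == GemmShape<128, 128, 64>,
// kStages == 4 and no padding, ShapeA is 128 x 256 elements of Operator::ElementA, ShapeB is
// 256 x 128 elements of Operator::ElementB, and for a fine-grained op with zero-points both
// ShapeScale and ShapeZero are 4 x 128 elements of ElementScale.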
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer to hold scales for threadblock
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;
/// Buffer to hold zero-points for threadblock
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA()
{
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB()
{
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref()
{
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref()
{
return TensorRefB{operand_B.data(), LayoutB()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
, warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = void>
class DqMmaMultistage;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
/// Internal structure exposed for introspection.
struct Detail
{
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
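// Illustrative only (hypothetical numbers): kAccessesPerGroupA/B is a ceiling division that
// spreads a stage's cp.async traffic across the warp-level MMA iterations. With
// AsyncCopyIterationsPerStageA == 8 and Base::kWarpGemmIterations == 4, each group issues
// two cp.async instructions for operand A.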
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
/// The group size for quantization
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
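// Worked example (hypothetical numbers): with Base::WarpCount == GemmShape<2, 2, 2>
// (8 warps) and warp_idx == 5: warp_idx_mn == 1, warp_idx_k == 1, warp_idx_m == 1 and
// warp_idx_n == 0, so this warp's A iterator is offset by one M tile and
// kWarpGemmIterations K tiles, and its B iterator by kWarpGemmIterationsForB K tiles.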
}
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
{
static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");
typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();
typename IteratorScale::AccessType* smem_scale_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
typename IteratorScale::AccessType* smem_zero_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());
int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
}
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
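// Illustrative only: the branching above advances the global scale/zero iterator per
// quantization group rather than per k tile. With group_size == 128 and Shape::kK == 64,
// two consecutive k tiles share one group, so the global iterator advances only on every
// second call (when row_groupsize64_ is odd); the shared-memory iterator still advances one
// row per call, so the shared row for such a group is simply written twice.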
CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
{
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
{
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
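// Illustrative only (hypothetical numbers): kSrcBytes above is the per-cp.async transfer
// size. For a 16-bit element with ThreadMap::kElementsPerAccess == 8 and
// kAccessesPerVector == 1, kSrcBytes == 16 * 8 / 1 / 8 == 16 bytes, i.e. one 128-bit
// cp.async per access.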
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum)
{
//
// Prologue
//
TransformBAfterLDS lds_converter;
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
{
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
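// At this point the prologue has issued Base::kStages - 1 complete stages of A, B and (for
// this fine-grained path) scales/zero-points, each closed by cp_async_fence() so it forms
// its own commit group. With kStages == 4, for example, three k tiles are in flight before
// any math is issued.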
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
{
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
typename Dequantizer::FragmentZero warp_frag_zeros;
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
warp_dequantizer_.add_pointer_offset(Shape::kN);
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);)
{
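// The termination condition lets gemm_k_iterations run past zero down to
// -(Base::kStages - 2): once the real k tiles are exhausted the iterator masks are cleared,
// further cp.async become no-ops, and the remaining iterations drain the Base::kStages - 1
// k tiles still staged in the shared-memory pipeline.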
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
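// Illustrative only: converted_frag_B holds B in ElementScale precision after the LDS
// converter, the warp dequantizer applies the per-group scales and zero-points, and the
// NumericArrayConverter then casts to ElementA so both operands passed to run_warp_mma have
// the type the warp-level MMA expects.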
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
if (group_start_iteration_B == 0)
{
copy_scales_and_advance(iterator_scale);
}
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1))
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
{
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1))
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
smem_read_stage_idx = 0;
}
else
{
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
}
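// Illustrative only: smem_write_stage_idx and smem_read_stage_idx walk the
// Base::kStages-deep circular shared-memory buffer. The write index starts at kStages - 1
// (the prologue already filled kStages - 1 stages), so the first pass through the branch
// above rewinds the shared-memory write iterators by kStages tiles; the read side rewinds
// the warp iterators and the dequantizer pointer only once kStages - 1 stages have been
// consumed.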
}
// Load the scale needed for the next tile iteration.
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
// Update internal pointer to set of scales in shared memory.
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
// Commit and drain all pending and predicated cp.async (LDGSTS) operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail
{
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
///< Group size for quantization. Not used by this main loop since it assumes per-column
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
{
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
{
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum)
{
//
// Prologue
//
TransformBAfterLDS lds_converter;
// NOTE - switch to ldg.sts
// Issue this first, so cp.async.commit_group will commit this load as well.
// Note: we do not commit here and this load will commit in the same group as
// the first load of A.
FragmentScale tb_frag_scales;
tb_frag_scales.clear();
iterator_scale.load(tb_frag_scales);
this->smem_iterator_scale_.store(tb_frag_scales);
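// Unlike the fine-grained specialization, this per-column path keeps a single row of scales
// in shared memory for the whole mainloop (ScalebiasStages == 1 in DqMmaBase), so no
// per-stage scale copies are issued below.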
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
{
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
{
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);)
{
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1))
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
{
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1))
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
smem_read_stage_idx = 0;
}
else
{
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
}
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
// Commit and drain all pending and predicated cp.async (LDGSTS) operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Used for partial specialization
typename Enable = void>
class DqMmaPipelined;
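// The two partial specializations of DqMmaPipelined (selected via std::enable_if on
// isFinegrained(QuantOp_)) are defined in the dq_mma_pipelined_finegrained.h and
// dq_mma_pipelined_percol.h headers included below.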
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
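// The fine-grained path stages Base::kStages rows of Shape::kN scales (and zeros, when present)
// in shared memory, one row per pipeline stage, as asserted above.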
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
using WarpFragmentZero = typename Dequantizer::FragmentZero;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< The group size for quantization
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale)
{
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
FragmentScale tb_frag_scales;
FragmentScale tb_frag_zeros;
tb_frag_scales.clear();
tb_frag_zeros.clear();
TransformScale transformScale;
using FragmentElement = typename FragmentScale::Element;
auto gmem_scale_ptr = iterator_scale.get_scale();
auto gmem_zero_ptr = iterator_scale.get_zero();
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
}
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
typename TransformScale::result_type tb_frag_zeros_fp16;
if (gmem_zero_ptr != nullptr)
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
if (iterator_scale.valid())
{
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
if (gmem_zero_ptr != nullptr)
{
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
}
}
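// Advance the global scale/zero iterator to the next quantization group: every k-tile when the
// k-tile covers a whole group (group_size 64, or group_size 128 with Shape::kK == 128), and only
// every other k-tile when a 64-wide k-tile spans half of a 128-wide group (tracked by
// row_groupsize64_ below).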
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) ///< source accumulator tile
{
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last k-block is loaded in the prologue
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
copy_scales_and_advance(iterator_scale);
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentScale warp_frag_scales;
WarpFragmentZero warp_frag_zero;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
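// The stage-0 scales/zeros were loaded above; advance the dequantizer to the next stage's row of
// scales in shared memory.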
warp_dequantizer_.add_pointer_offset(Shape::kN);
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
iterator_scale.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping the k-group index around if this is the
// last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
copy_scales_and_advance(iterator_scale);
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
iterator_scale.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
// Load the scales needed for the next tile iteration
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Update internal pointer to the set of scales in shared memory
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< Unused; present only so the signature matches the fine-grained specialization
///< and both variants compile identically. DqMmaPipelined is only enabled for sm<80,
///< so this argument has no effect on sm>=80 builds.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) ///< source accumulator tile
{
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last k-block is loaded in the prologue
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
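// Per-column scaling: the scale fragment is loaded once here and reused for every k-tile; there
// is no per-stage scale pipeline as in the fine-grained variant.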
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping the k-group index around if this is the
// last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for m-by-n-by-kgroup
template <
/// Shape of one matrix product operation (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements,
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
{
private:
// Shape of the f16/bf16 compute instruction
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
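// B stays in its narrow type in shared memory and is read with the wider LoadInstructionShape,
// while the HMMA itself runs at InstructionShape on f16/bf16 operands; MmaTensorOpComputeBWithF16
// bridges the two K extents via its kExpansionFactor.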
public:
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sm89.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool>
class MmaTensorOpComputeBWithF16
{
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
&& ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
static_assert(
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
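// kExpansionFactor is the ratio of the shared-memory load K extent to the compute K extent
// (e.g. 2 when int4 weights are loaded 32 wide in K but computed with a 16-wide HMMA, 1 for
// int8); warp_tileB_k_offset in operator() selects which compute-K slice of a loaded B fragment
// is fed to the underlying mma.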
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
int const warp_tileB_k_offset) const
{
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
static_assert(
TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
"Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
"B");
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Number of threads participating in one matrix operation
int Threads,
///
WeightOnlyQuantOp QuantOp_,
///
typename Enable = void>
class MmaTensorOpDequantizer;
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
///
WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_,
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture-specific MMA operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = bfloat16_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
// We need 1 bf16 scale per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{
int const warp_offset = warp_idx_n * Shape::kN;
int const quad = lane_idx / 4;
int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp))
{
pointer_zero_ = smem_zeros.data() + thread_offset;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag)
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
{
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
}
CUTLASS_DEVICE
void dequantize(
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
__nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]);
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
}
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset)
{
static_assert(sizeof(ElementScale) > 1, "");
pointer_scale_ += offset;
pointer_zero_ += offset;
}
private:
ElementScale const* pointer_scale_;
ElementScale const* pointer_zero_;
};
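// Illustrative usage sketch (not part of this header): inside a warp-level mainloop, the
// dequantizer is typically constructed from the shared-memory scale/zero tensors, scales are
// loaded once per k-group, and each B fragment is rescaled in registers before the warp MMA.
// Names such as `Dequantizer`, `warp_frag_B`, and `warp_frag_scale` below are placeholders.
//
//   Dequantizer warp_dequantizer(smem_scales_ref, smem_zeros_ref, warp_idx_n, lane_idx);
//   typename Dequantizer::FragmentScale warp_frag_scale, warp_frag_zero;
//   warp_dequantizer.load(warp_frag_scale, warp_frag_zero);
//   warp_dequantizer.dequantize(warp_frag_B, warp_frag_scale, warp_frag_zero);
//   // warp_frag_B now holds bf16 values scale * q (+ zero) and can be fed to the warp MMA.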
////////////////////////////////////////////////////////////////////////////////
// Specialization for Turing & Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
///
WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 75
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = half_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{
int const warp_offset = warp_idx_n * Shape::kN;
int const quad = lane_idx / 4;
int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp))
{
pointer_zero_ = smem_zeros.data() + thread_offset;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag)
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
multiplies<ExpandedMmaOperandB> mul_op;
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
{
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
}
CUTLASS_DEVICE
void dequantize(
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
multiplies<ExpandedMmaOperandB> mul_op;
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
if constexpr (hasZero(QuantOp))
{
plus<ExpandedMmaOperandB> plus_op;
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter]
= plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset)
{
static_assert(sizeof(ElementScale) > 1, "");
pointer_scale_ += offset;
pointer_zero_ += offset;
}
private:
ElementScale const* pointer_scale_;
ElementScale const* pointer_zero_;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
namespace tensorrt_llm
{
namespace cutlass_extensions
{
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// SIMT config
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=16
CtaShape16x128x64_WarpShape16x32x64,
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
// Warp configs for M=64
CtaShape64x128x64_WarpShape32x64x64,
CtaShape64x64x128_WarpShape32x64x64,
CtaShape64x128x64_WarpShape64x32x64,
// Warp configs for M=128
CtaShape128x64x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x64x64,
CtaShape128x128x64_WarpShape128x32x64,
CtaShape128x256x64_WarpShape64x64x64,
// Warp configs for M=256
CtaShape256x128x64_WarpShape64x64x64,
// TensorCore config CTA_N = 64, CTA_K = 128
CtaShape128x64x128_WarpShape64x32x128,
// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64,
// TensorCore config CTA_N = 256, CTA_K = 128
CtaShape16x256x128_WarpShape16x64x128
};
enum class SplitKStyle
{
NO_SPLIT_K,
SPLIT_K_SERIAL,
STREAM_K, // Sm80+
// SPLIT_K_PARALLEL // Not supported yet
};
enum class CutlassTileConfigSM90
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// CTA configs for M=64
CtaShape64x16x128B,
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
// CTA configs for M=128
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
// CTA configs for M=256
CtaShape256x128x128B,
};
enum class MainloopScheduleType
{
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
// defaults to the "legacy" main loop schedule.
};
enum class EpilogueScheduleType
{
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
// architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
};
enum class ClusterShape
{
ClusterShape_1x1x1,
ClusterShape_2x1x1,
ClusterShape_1x2x1,
ClusterShape_2x2x1,
ClusterShape_1x8x1,
ClusterShape_8x1x1
};
struct CutlassGemmConfig
{
enum CandidateConfigTypeParam : int
{
NONE = 0,
WEIGHT_ONLY = 1u << 0,
SIMT_ONLY = 1u << 1,
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3,
GROUPED_GEMM = 1u << 4,
FP8_ONLY = 1u << 5,
};
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;
// config options for sm90
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
bool is_sm90 = false;
CutlassGemmConfig() {}
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
, split_k_style(split_k_style)
, split_k_factor(split_k_factor)
, stages(stages)
, is_sm90(false)
{
}
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, is_sm90(true)
{
}
std::string toString() const
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";
if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=TMA"
<< "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
assert(!is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=compatible"
<< "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
<< "\n\tsplit k: " << (int) split_k_factor;
}
else
{
tactic << "\n\tundefined";
}
tactic << "\n";
return tactic.str();
}
};
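// Illustrative usage sketch (not part of this header): constructing a pre-Hopper config and
// printing it. The specific tile shape, split-k factor, and stage count below are arbitrary.
//
//   using namespace tensorrt_llm::cutlass_extensions;
//   CutlassGemmConfig cfg(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
//       SplitKStyle::SPLIT_K_SERIAL, /*split_k_factor=*/2, /*stages=*/3);
//   std::cout << cfg.toString(); // prints style=compatible plus the tile shape ID, stages, split k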
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
// clang-format off
if (config.is_sm90)
{
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
<< ", cluster_shape_enum: " << int(config.cluster_shape);
}
else
{
out << "tile_config_enum: " << int(config.tile_config)
<< ", split_k_style_enum: " << int(config.split_k_style)
<< ", split_k_factor: " << config.split_k_factor
<< ", stages: " << config.stages;
}
// clang-format on
return out;
}
} // namespace cutlass_extensions
} // namespace tensorrt_llm
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
namespace cutlass
{
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
{
using result_type = Array<half_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
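// Worked example of the arithmetic above (illustrative): prmt packs a biased byte b into the low
// byte of 0x64xx, producing the fp16 bit pattern 0x6400 | b, which encodes the value 1024 + b
// (the byte lands in the low bits of the 10-bit mantissa). Subtracting 0x6480 = 1152 = 1024 + 128
// then removes both the construction offset and the +128 bias:
//   b = 0x00 (original -128): 1024 +   0 - 1152 = -128.0
//   b = 0x80 (original    0): 1024 + 128 - 1152 =    0.0
//   b = 0xFF (original  127): 1024 + 255 - 1152 =  127.0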
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using result_type = Array<half_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
{
using result_type = Array<bfloat16_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
// Subtract out fp32_base + 128 to make the unsigned integer signed.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 4; ++ii)
{
fp32_intermediates[ii] -= 8388736.f;
}
// Truncate the fp32 representation and pack up as bfloat16s.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii)
{
bf16_result_ptr[ii]
= __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
result.clear(); // Suppress compiler warning
arch::device_breakpoint();
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
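// Worked example of the arithmetic above (illustrative): __byte_perm drops a biased byte b into
// the low byte of 0x4B000000 (the fp32 encoding of 2^23 = 8388608), so each intermediate holds
// 8388608 + b. Subtracting 8388736 = 2^23 + 128 removes the offset and the +128 bias, e.g.
// b = 0x80 yields 0.0f and b = 0xFF yields 127.0f. The final __byte_perm with selector 0x7632
// keeps the upper 16 bits of each fp32 pair, i.e. a truncating fp32 -> bf16 conversion.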
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
{
using result_type = Array<half_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits beforehand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
static constexpr uint32_t NEG_72 = 0xd480d480;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
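// Worked example of the arithmetic above (illustrative), for a stored nibble v (original value v - 8):
//   low-nibble lanes:  (v | 0x6400) encodes 1024 + v;        1024 + v - 1032              = v - 8
//   high-nibble lanes: (v << 4 | 0x6400) encodes 1024 + 16*v; (1024 + 16*v) * (1/16) - 72 = v - 8
// so the sub and fma paths produce the same dequantized result without an extra shift.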
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
{
static constexpr int VEC_WIDTH = 8;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
using result_type = Array<half_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
{
using result_type = Array<bfloat16_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
// No shift needed for first item.
uint32_t i4s = source_i4s;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
CUTLASS_PRAGMA_UNROLL
for (int ii = 1; ii < result_type::kElements / 2; ++ii)
{
i4s >>= sizeof_bits<typename source_type::Element>::value;
// (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
}
// This is the BF16 {-136, -136} represented as an integer.
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
// Finally, we construct the output numbers.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < result_type::kElements / 2; ++ii)
{
// Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
arch::device_breakpoint();
result.clear(); // Suppress compiler warning.
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
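// Worked example of the arithmetic above (illustrative), for a stored nibble v (original value v - 8):
// (v | 0x4300) is the bf16 encoding of 128 + v, and the fma computes (128 + v) * 1.0 - 136 = v - 8.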
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
{
static constexpr int VEC_WIDTH = 8;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines new layouts needed for MoE
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"
namespace cutlass
{
namespace layout
{
template <int RowsPerTile, int ColumnsInterleaved>
struct ColumnMajorTileInterleave
{
static constexpr int kRowsPerTile = RowsPerTile;
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};
template <class T>
struct IsColumnMajorTileInterleave
{
static constexpr bool value = false;
};
template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
{
static constexpr bool value = true;
};
} // namespace layout
} // namespace cutlass
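// Illustrative usage of the trait above (the 64x4 tile parameters are arbitrary):
//   using Interleaved = cutlass::layout::ColumnMajorTileInterleave<64, 4>;
//   static_assert(cutlass::layout::IsColumnMajorTileInterleave<Interleaved>::value, "");
//   static_assert(!cutlass::layout::IsColumnMajorTileInterleave<cutlass::layout::ColumnMajor>::value, "");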
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
quantization.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace transform
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
class FineGrainedScaleZeroIterator;
template <typename Shape_, typename Element_, int Alignment_>
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
{
public:
using Shape = Shape_;
using Element = Element_;
using Layout = layout::RowMajor;
static int const kAdvanceRank = 0;
static int const kAlignment = Alignment_;
static int const kAccessesPerVector = 1;
/// Row index of scales corresponding to the groupsize of 64
int row_groupsize64_;
int group_size_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using AccessType = AlignedArray<Element, kAlignment>;
using Fragment = cutlass::Array<Element, kAlignment>;
// For compatibility with existing iterator interface
struct Params
{
LongIndex stride_ = 0;
/// amount (in bytes) to increment pointer from first access of current tile
/// to first access of next tile
LongIndex inc_advance_ = 0;
// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: stride_(layout.stride(0))
{
inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
}
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
/// Parameters object with precomputed internal state
Params const params_;
/// Internal pointer to first access of tile
BytePointer pointer_scale_;
BytePointer pointer_zero_;
bool is_valid_ = false;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_DEVICE
FineGrainedScaleZeroIterator(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of scale tensor
Pointer pointer_scale,
///< Pointer to start of zero tensor
Pointer pointer_zero,
///< Extent of the scale and bias
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
///< Group size
int group_size)
: params_(params)
, pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
, pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
{
row_groupsize64_ = threadblock_offset.row();
group_size_ = group_size;
const LongIndex tb_row_byte_offset
= threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);
if (pointer_zero_ != nullptr)
{
pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
}
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
int const thread_row = thread_id / THREADS_PER_ROW;
int const thread_col = thread_id % THREADS_PER_ROW;
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
if (pointer_zero_ != nullptr)
{
pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
}
// For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
// a given iteration. The same threads will be responsible for issuing reads since the number of scales
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
// outside of the constructor.
int const global_row = threadblock_offset.row() + thread_row;
int const global_col = threadblock_offset.column() + thread_col * kAlignment;
bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
bool const col_in_bounds = global_col < extent.column();
is_valid_ = row_in_bounds && col_in_bounds;
}
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
Pointer pointer_scale, ///< Pointer to start of scale tensor
Pointer pointer_zero, ///< Pointer to start of zero tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
int group_size)
: FineGrainedScaleZeroIterator(
params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
{
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& tile_offset)
{
const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
pointer_scale_ += row_byte_offset + col_byte_offset;
if (pointer_zero_ != nullptr)
{
pointer_zero_ += row_byte_offset + col_byte_offset;
}
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
{
is_valid_ &= (!enable);
}
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() const
{
return is_valid_;
}
/// Returns a scale pointer
CUTLASS_HOST_DEVICE
AccessType* get_scale() const
{
return reinterpret_cast<AccessType*>(pointer_scale_);
}
/// Returns a zero pointer
CUTLASS_HOST_DEVICE
AccessType* get_zero() const
{
return reinterpret_cast<AccessType*>(pointer_zero_);
}
};
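// Illustrative usage sketch (not part of this header): a threadblock constructs the iterator once
// for its (row, column) offset and group size, then advances it between k groups. The shape,
// alignment, and names such as `scale_ptr`, `zero_ptr`, `extent`, and `tb_offset` are placeholders.
//
//   using Iterator = FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, 128>, cutlass::half_t,
//       cutlass::layout::RowMajor, 0, 8>;
//   Iterator::Params params(layout);        // layout.stride(0) = leading dimension of the scales
//   Iterator iter(params, scale_ptr, zero_ptr, extent, threadIdx.x, tb_offset, group_size);
//   if (iter.valid()) { auto* scales = iter.get_scale(); auto* zeros = iter.get_zero(); }
//   iter.add_tile_offset({1, 0});            // move to the scales for the next k group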
} // namespace threadblock
} // namespace transform
} // namespace cutlass
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
using namespace cute;
/// Function object that applies an index to its argument
template <class Iter>
struct IndexedGather
{
CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
: indices_(indices)
{
}
template <typename I>
CUTE_HOST_DEVICE constexpr auto operator()(I i) const
{
return indices_[i];
}
CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
{
cute::print("Indexed{");
print(s.indices_);
print("}");
}
Iter indices_;
};
/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
: func_(func)
, stride_(stride)
{
}
template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
{
return s.func_(i) * s.stride_;
}
template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
{
return s.func_(i) * s.stride_;
}
CUTE_HOST_DEVICE friend void print(CustomStride const& s)
{
cute::print("Custom{");
print(s.func_);
cute::print(",");
print(s.stride_);
cute::print("}");
}
template <class Div>
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
}
// Circumvent the requirement on make_layout that shape and stride are integral
template <class Shape>
CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
{
return Layout<Shape, CustomStride>(shape, stride);
}
Func func_;
Stride stride_;
};
template <class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
{
// Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
return make_layout(
repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}
/// Helper function to optionally create a gather tensor
template <class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
{
Layout matrix_layout = make_identity_layout(shape);
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}
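// Illustrative usage sketch: gather the rows of an M x K row-major matrix through an index array.
// `d_ptr` and `d_indices` are placeholder device pointers, not part of this header.
//
//   auto gathered = make_gather_tensor(cute::make_gmem_ptr(d_ptr),
//       cute::make_shape(M, K), cute::make_stride(K, cute::_1{}), IndexedGather{d_indices});
//   // gathered(i, k) reads d_ptr[d_indices[i] * K + k]: the leading stride K is replaced by a
//   // CustomStride that applies the index function before multiplying by the stride.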
namespace cute
{
template <int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
{
if constexpr (is_tuple<Shape>::value)
{
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
}
else if constexpr (is_scaled_basis<Stride>::value)
{
if constexpr (Stride::mode() == I)
{
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
}
else
{
return make_layout(shape, stride);
}
}
else
{
return upcast<N>(shape, stride);
}
CUTE_GCC_UNREACHABLE;
}
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(
ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
{
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
// Upcast the outer layout (works as expected)
auto outer = upcast<N>(layout.layout_a());
// Upcast the accumulated offset along stride-1 mode
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
// Upcast the inner layout's shape along stride-1 mode
auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());
return composition(outer, offset, inner);
}
} // namespace cute
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
namespace cutlass
{
enum class WeightOnlyQuantOp
{
UNDEFINED,
PER_COLUMN_SCALE_ONLY,
FINEGRAINED_SCALE_ONLY,
FINEGRAINED_SCALE_AND_ZEROS
};
constexpr bool isFinegrained(WeightOnlyQuantOp op)
{
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}
constexpr bool hasZero(WeightOnlyQuantOp op)
{
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
}
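// Illustrative truth table (follows directly from the definitions above):
//   static_assert(isFinegrained(WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "");
//   static_assert(isFinegrained(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "");
//   static_assert(!isFinegrained(WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY), "");
//   static_assert(hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "");
//   static_assert(!hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "");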
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy);
}
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
#include <tensorrt_llm/common/cudaUtils.h>
namespace tensorrt_llm::kernels::cutlass_kernels
{
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
typename EpilogueTag>
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
int* kernel_occupancy)
{
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
TileK_, Stages_, activation_type>;
// make sure the GPU has enough resources.
if (kernel_occupancy != nullptr)
{
constexpr int smem_size = GemmType::kSmemSize;
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr{};
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
// heuristic to ignore this configuration.
*kernel_occupancy = 0;
return;
}
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
*kernel_occupancy = max_active_blocks;
return;
}
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
num_experts, threadblock_count};
auto params = GemmType::to_underlying_arguments(args);
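// Kernels that need more than 48 KiB of dynamic shared memory must opt in explicitly before launch.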
if (GemmType::kSmemSize >= (48 << 10))
{
cudaError_t result = cudaFuncSetAttribute(
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
}
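// Launch the kernel on a 1D grid of threadblock_count blocks with kThreadCount threads per block.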
dim3 grid(params.threadblock_count, 1, 1);
dim3 block(GemmType::kThreadCount);
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
auto result = cudaGetLastError();
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Failed to execute the fused MoE kernel, cuda error %d\n", (int) (result));
}
} // namespace tensorrt_llm::kernels::cutlass_kernels
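// Usage sketch (illustrative only, not part of the library): a caller that has already computed
// per-expert token counts could instantiate this launcher roughly as below. The tile and stage
// template arguments are placeholder values, not tuned configurations.
//
//   sm80_generic_fused_moe_gemm_kernelLauncher<half, half, /*MaxTileM_=*/16, /*TileN_=*/128,
//       /*TileK_=*/64, /*Stages_=*/4, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(
//       A, B, biases, /*bias_is_broadcast=*/true, C, total_tokens_including_expert, num_rows,
//       gemm_n, gemm_k, num_experts, multi_processor_count, stream, /*kernel_occupancy=*/nullptr);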
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// Keep in sync with the signature generated by generate_kernels.py
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
#include <sstream>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
// Hopper helper class for defining all the cutlass helper types
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
struct HopperGroupedGemmInfo
{
using Arch = cutlass::arch::Sm90;
// TODO Update once mixed input support is added
static_assert(cutlass::platform::is_same<T, WeightType>::value,
"CUTLASS does not currently have specialised SM90 support for quantized operations");
#ifdef ENABLE_FP8
constexpr static bool IsFP8
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
#else
constexpr static bool IsFP8 = false;
#endif
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for bfloat16, half, float, fp8");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
"Specialized for half, float, fp8");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
"Unexpected quantization type");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
// For legacy reasons we convert unsigned 8-bit to signed
using CutlassWeightTypeMaybeUint8
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
CutlassWeightTypeMaybeUint4>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
using ElementA = ElementType;
using ElementB = CutlassWeightType;
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
using ElementC = void;
using ElementCNoVoid = ElementD;
using ElementAccumulator = float;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// A matrix configuration - this is transposed and swapped with B
using LayoutA = HopperGroupedGemmInput::LayoutA;
constexpr static int AlignmentA
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
// of elements (up to 16 bytes)
// B matrix configuration - this is transposed and swapped with A
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
constexpr static int AlignmentB
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
// of elements (up to 16 bytes)
// C matrix configuration
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
using StrideC = HopperGroupedGemmInput::StrideC;
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
constexpr static int AlignmentC
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
// of elements (up to 16 bytes)
// D matrix configuration
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
constexpr static int AlignmentD
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
// in units of elements (up to 16 bytes)
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
"Hopper Grouped GEMM specialisation doesn't support fused activation");
using EpilogueOp
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
// TODO Add mode for fused activation once CUTLASS adds support
// using EpilogueSchedule = cutlass::platform::conditional_t<
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
// cutlass::epilogue::?????????????????? // <<<<<< TODO: which epilogue schedule supports fused activations?
// >;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
// Epilogue For Default Finalize
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
TileShape, ClusterShape, //
cutlass::epilogue::collective::EpilogueTileAuto, //
ElementAccumulator, ElementAccumulator, //
ElementC, LayoutC*, AlignmentC, //
ElementD, LayoutD*, AlignmentD, //
EpilogueSchedule>::CollectiveOp;
// Epilogue For Fused Finalize
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
TileShape, //
ElementCNoVoid, StrideC*, //
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
ElementAccumulator, //
ElementAccumulator, //
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
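// Carve the epilogue's shared storage out of the shared memory budget before the mainloop stage count is chosen automatically.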
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
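// FP8 inputs use the fast-accumulation variant of the ptr-array TMA warp-specialized cooperative schedule.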
using KernelSchedule
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
Arch, cutlass::arch::OpClassTensorOp, //
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
ElementType, LayoutA*, AlignmentA, //
ElementAccumulator, //
TileShape, ClusterShape, //
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
// Hopper specialised version
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
using namespace cute;
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
{
using GemmInfo
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
using ElementA = typename GemmInfo::ElementA;
using ElementB = typename GemmInfo::ElementB;
using ElementC = typename GemmInfo::ElementC;
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
using ElementD = typename GemmInfo::ElementD;
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
using GemmKernel = typename GemmInfo::GemmKernel;
using GemmGrouped = typename GemmInfo::GemmGrouped;
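// Occupancy-only query: report the kernel's occupancy for the heuristic and return without launching anything.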
if (kernel_occupancy != nullptr)
{
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
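// Provide the device's SM count so CUTLASS can size the persistent grid appropriately.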
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = multi_processor_count;
GemmGrouped gemm;
if (workspace_size != nullptr)
{
// Make a mock problem shape with just the minimal information actually required to get the workspace size.
// This relies on assumptions about CUTLASS's implementation, which is not ideal. We have a check later to
// catch future CUTLASS updates causing silent breakages, but that is not foolproof.
// The alternative is to wait until we have the data and then dynamically allocate the workspace.
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
typename GemmGrouped::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
*workspace_size = gemm.get_workspace_size(args);
return;
}
using MainloopArguments = typename CollectiveMainloop::Arguments;
TLLM_CHECK(hopper_input.stride_a);
TLLM_CHECK(hopper_input.stride_b);
TLLM_CHECK(hopper_input.ptr_a);
TLLM_CHECK(hopper_input.ptr_b);
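// Note: A and B are passed swapped (B first), matching the operand order of the CollectiveBuilder mainloop above.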
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
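// Epilogue scalars: alpha is fixed to 1; beta is 1 only when a C (bias) pointer is provided, otherwise 0.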
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
auto make_epi_args = [&]()
{
if constexpr (FUSION == EpilogueFusion::NONE)
{
auto epi_params = hopper_input.default_epilogue;
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
}
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
{
// Parameters for fused finalize
auto epi_params = hopper_input.fused_finalize_epilogue;
return EpilogueArguments{
epilogue_scalars, // Parameters to underlying epilogue
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output, // D (output) params
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
epi_params.stride_bias, // Bias params
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
epi_params.ptr_expert_first_token_offset, // Offset of each expert's first token in the router scales
epi_params.ptr_source_token_index, // Index of the source token to sum into
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
};
}
else
{
static_assert(
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
}
};
EpilogueArguments const epilogue_params = make_epi_args();
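// Tile scheduler arguments: swizzle size 1, rasterizing output tiles along N.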
typename GemmKernel::TileScheduler::Arguments scheduler_args{
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
mainloop_params, epilogue_params, hw_info, scheduler_args};
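// Check that the preallocated workspace is large enough for these arguments before initializing the GEMM.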
size_t calculated_ws_size = gemm.get_workspace_size(args);
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
auto can_implement = gemm.can_implement(args);
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
"Failed to initialize cutlass SM90 grouped gemm. Error: "
+ std::string(cutlassGetStatusString(init_status)));
auto run_status = gemm.run(stream);
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
sync_check_cuda_error();
}
else
{
TLLM_THROW("Configuration was disabled by FAST_BUILD");
}
#else // COMPILE_HOPPER_TMA_GEMMS
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm