Unverified Commit 3ee62235 authored by Yineng Zhang, committed by GitHub

revert the MoE dependence (#3230)

parent 9829e77e
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdlib>
#if !defined(_MSC_VER)
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#endif
#include <sstream>
namespace tensorrt_llm::common
{
namespace
{
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
}
#if !defined(_MSC_VER)
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{""}
{
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
auto const trace = getTrace();
std::runtime_error::operator=(
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
}
#else
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
, mNbFrames{}
{
}
#endif
TllmException::~TllmException() noexcept = default;
std::string TllmException::getTrace() const
{
#if defined(_MSC_VER)
return "";
#else
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
std::ostringstream buf;
for (auto i = 1; i < mNbFrames; ++i)
{
Dl_info info;
if (dladdr(mCallstack[i], &info) && info.dli_sname)
{
auto const clearName = demangle(info.dli_sname);
buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
}
else
{
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
}
if (i < mNbFrames - 1)
buf << std::endl;
}
if (mNbFrames == MAX_FRAMES)
buf << std::endl << "[truncated]";
std::free(trace);
return buf.str();
#endif
}
std::string TllmException::demangle(char const* name)
{
#if defined(_MSC_VER)
return name;
#else
std::string clearName{name};
auto status = -1;
auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
if (status == 0)
{
clearName = demangled;
std::free(demangled);
}
return clearName;
#endif
}
} // namespace tensorrt_llm::common
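// Illustrative sketch, not part of this diff: demangle() above is a public static helper, so it can
// also be used on its own to pretty-print type names (getTrace() uses it to clean up backtrace
// symbols). exampleTypeName() is a hypothetical function that exists only for this example; on MSVC
// builds demangle() simply returns the name unchanged.
#include <string>
#include <typeinfo>
#include <vector>

#include "tensorrt_llm/common/tllmException.h"

inline std::string exampleTypeName()
{
    // With libstdc++/libc++ this yields something like "std::vector<int, std::allocator<int>>".
    return tensorrt_llm::common::TllmException::demangle(typeid(std::vector<int>).name());
}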
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <cstddef>
#include <stdexcept>
#include <string>
#define NEW_TLLM_EXCEPTION(...) \
tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__))
namespace tensorrt_llm::common
{
class TllmException : public std::runtime_error
{
public:
static auto constexpr MAX_FRAMES = 128;
explicit TllmException(char const* file, std::size_t line, std::string const& msg);
~TllmException() noexcept override;
[[nodiscard]] std::string getTrace() const;
static std::string demangle(char const* name);
private:
std::array<void*, MAX_FRAMES> mCallstack{};
int mNbFrames;
};
} // namespace tensorrt_llm::common
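// Illustrative sketch, not part of this diff: a typical throw/catch of the exception defined above
// via the NEW_TLLM_EXCEPTION macro. checkPositive() and exampleMain() are hypothetical names used
// only for this example; fmtstr() is assumed to come from tensorrt_llm/common/stringUtils.h, as in
// the .cpp above.
#include <cstdio>

#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"

namespace
{
void checkPositive(int value)
{
    if (value <= 0)
    {
        // printf-style arguments are formatted into the message; file, line and (on non-MSVC
        // builds) a demangled backtrace are appended by the constructor.
        throw NEW_TLLM_EXCEPTION("expected a positive value, got %d", value);
    }
}
} // namespace

int exampleMain()
{
    try
    {
        checkPositive(-1);
    }
    catch (tensorrt_llm::common::TllmException const& e)
    {
        std::printf("%s\n", e.what());
    }
    return 0;
}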
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace tensorrt_llm::common
{
std::uintptr_t constexpr kCudaMemAlign = 128;
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
uintptr_t addr = (uintptr_t) ptr;
if (addr % to)
{
addr += to - addr % to;
}
return (int8_t*) addr;
}
constexpr size_t alignSize(size_t size, size_t to)
{
if ((size % to) != 0U)
{
size += to - size % to;
}
return size;
}
inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
{
uintptr_t addr = (uintptr_t) ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t*) addr, alignment);
}
inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}
inline int8_t* nextWorkspacePtr(
int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
{
uintptr_t curr_offset = offset;
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
offset = next_offset;
return newptr;
}
inline int8_t* nextWorkspacePtrWithAlignment(
int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
inline size_t calculateTotalWorkspaceSize(
size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)
{
total += workspaces[i];
if (workspaces[i] % alignment)
{
total += alignment - (workspaces[i] % alignment);
}
}
return total;
}
}; // namespace tensorrt_llm::common
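// Illustrative sketch, not part of this diff: carving two sub-buffers out of one workspace
// allocation with the helpers above. The buffer sizes and exampleWorkspaceLayout() are made up for
// this example, and it assumes the header above is already included. With the default 128-byte
// alignment, 1000 bytes rounds up to 1024, so the second buffer starts at offset 1024 and the total
// is 5120 bytes.
#include <cstdint>
#include <vector>

inline void exampleWorkspaceLayout()
{
    using namespace tensorrt_llm::common;

    size_t const sizes[2] = {1000, 4096};
    size_t const total = calculateTotalWorkspaceSize(sizes, 2); // 1024 + 4096 = 5120

    std::vector<int8_t> storage(total);
    uintptr_t offset = 0;
    int8_t* buf0 = nextWorkspacePtr(storage.data(), offset, sizes[0]); // storage.data() + 0
    int8_t* buf1 = nextWorkspacePtr(storage.data(), offset, sizes[1]); // storage.data() + 1024
    (void) buf0;
    (void) buf1;
}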
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/numeric/numeric_types.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif
namespace cute
{
//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////
template <class T>
struct TypedAtomicAdd
{
using SRegisters = T[1];
using DRegisters = T[1];
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
{
atomicAdd(&dst, src);
}
};
template <class T>
struct Copy_Traits<TypedAtomicAdd<T>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////
struct SM70_RED_ADD_NOFTZ_F16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM70_RED_ADD_NOFTZ_F16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
} // end namespace cute
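// Illustrative sketch, not part of this diff: the reduction ops above are meant to be wrapped in a
// cute::Copy_Atom so that cute::copy() issues a global red.add instead of a plain store (this is
// how CopyAtomR2G is built in the MoE epilogue further down). redAddBf16x2Example() is a
// hypothetical helper for a single packed bf16x2 value; it must be compiled for SM90 or newer,
// otherwise the op traps via CUTE_INVALID_CONTROL_PATH.
#include <cute/atom/copy_atom.hpp>
#include "cutlass/numeric_types.h"

namespace example
{
// Wrapping the op in a Copy_Atom lets it be used with cute::copy on partitioned tensors.
using RedAtomBF16x2 = cute::Copy_Atom<cute::SM90_RED_ADD_NOFTZ_BF16x2, cutlass::bfloat16_t>;

// Direct use of the op: atomically adds one packed bf16x2 register value into global memory.
CUTE_HOST_DEVICE void redAddBf16x2Example(uint32_t src_packed, uint32_t* gmem_dst)
{
    cute::SM90_RED_ADD_NOFTZ_BF16x2::copy(src_packed, *gmem_dst);
}
} // namespace example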
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace arch
{
// Tag which triggers an MMA that dequantizes interleaved B to the type of A before the multiply-add
struct OpMultiplyAddDequantizeInterleavedBToA;
/*
Below we have extra tags to signal what kind of dequantization we want to do
(per-column scale, fine-grained scale only, fine-grained scale with zeros). This
still lets us use the existing template infrastructure (incl. that in CUTLASS).
However, we split the tagged operator back into OpMultiplyAddDequantizeInterleavedBToA
along with the quantization op before instantiating the GEMM pieces.
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
// The default just forwards the original operator
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
using TaggedOperator = MmaOp;
};
// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};
// Here we instantiate some structs to "detag" the tagged operator. This splits it back into the original
// operator + the extra information. If no extra info was tagged, the dequant op defaults to per-column
// scaling.
template <typename TaggedMmaOp>
struct DetagOperator
{
using Operator = TaggedMmaOp;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};
} // namespace arch
} // namespace cutlass
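// Illustrative compile-time check, not part of this diff: TagOperator and DetagOperator above are
// pure type-level plumbing, so the round trip can be verified with static_asserts. The
// example_detag_check namespace is hypothetical and exists only for this sketch; it is nested under
// cutlass::arch so the names resolve exactly as they do in the code above.
#include <type_traits>

namespace cutlass
{
namespace arch
{
namespace example_detag_check
{
using Tagged = TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
    WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>::TaggedOperator;
using Detagged = DetagOperator<Tagged>;

static_assert(std::is_same<Detagged::Operator, OpMultiplyAddDequantizeInterleavedBToA>::value,
    "Detagging must recover the original MMA operator");
static_assert(Detagged::QuantOp == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS,
    "Detagging must recover the tagged quantization op");
} // namespace example_detag_check
} // namespace arch
} // namespace cutlass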
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace cutlass_extensions
{
template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr;
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
}
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
// configuration.
return 0;
}
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
}
int max_active_blocks = -1;
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
}
return max_active_blocks;
}
} // namespace cutlass_extensions
} // namespace tensorrt_llm
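// Illustrative sketch, not part of this diff: how the occupancy helper above is typically consumed
// by a config heuristic. configIsViable() and MyGemmKernel are hypothetical placeholders; pass
// `true` as the second template argument for CUTLASS 3.x kernels so cutlass::device_kernel<> is
// queried instead of cutlass::Kernel<>.
template <typename MyGemmKernel>
bool configIsViable()
{
    int const occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<MyGemmKernel>();
    // An occupancy of 0 means the kernel's shared-memory demand cannot be satisfied on this
    // device, so the heuristic should skip this configuration.
    return occupancy > 0;
}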
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cute/numeric/numeric_types.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
class CopyOpR2G>
class EpilogueMoeFusedFinalize
{
public:
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
using StrideD = StrideD_;
using InternalStrideD = cute::remove_pointer_t<StrideD>;
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
struct SharedStorage
{
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
};
struct TensorMapStorage
{
};
struct Arguments
{
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C{};
StrideC dC{};
ElementD* ptr_D{};
StrideD dD{};
ElementBias const* ptr_bias;
StrideBias dBias{};
ElementScale const* ptr_scale;
StrideScale dScale{};
int64_t const* group_offset{};
int32_t const* scatter_index{};
cutlass::FastDivmod num_rows_in_final_output;
};
using Params = Arguments;
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
{
return args;
}
template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
{
return 0;
}
template <class ProblemShape>
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
{
return cutlass::Status::kSuccess;
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
{
bool implementable = true;
if (problem_shape.is_host_problem_shape_available())
{
// Check alignment for all problem sizes
for (int i = 0; i < problem_shape.groups(); i++)
{
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M, N, K, L] = problem_shape_MNKL;
implementable = implementable
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
}
}
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
"reduction instruction.\n");
}
return implementable;
}
CUTLASS_HOST_DEVICE
EpilogueMoeFusedFinalize(Params const& params_)
: params(params_)
{
}
CUTLASS_DEVICE
bool is_source_needed()
{
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
return params.ptr_C != nullptr
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
}
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
class TiledMma, class ResidueMNK>
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
auto synchronize = [&]()
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Batches are managed by using appropriate pointers to C and D matrices
int32_t const mock_L = 1;
int32_t const mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
// we get the correct alpha/beta values for the current batch/group using group index.
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
Tensor sD = as_position_independent_swizzle_tensor(sD_);
// Function to scatter output rows
auto& num_rows = params.num_rows_in_final_output;
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
auto get_scatter_idx = [&](auto i)
{
auto scatter = read_scatter_map(i);
int quot, rem;
num_rows(quot, rem, scatter);
return rem;
};
// Represent the full output tensor
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
Tensor mD_mnl = make_gather_tensor(
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
// Use a fake shape for the bias; the actual extent doesn't matter here
bool const is_bias_needed = params.ptr_bias != nullptr;
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
Tensor mScale_mnl = make_tensor(
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
Tensor gC_mnl
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gBias_mnl
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gScale_mnl
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Get the smallest tiled copy we can use to retile the accumulators
TiledCopy tiled_copy_C_atom
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
// Make a tiled copy vectorized along major direction of D
auto tiled_s2r = [&]()
{
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
}
else
{
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
}
}();
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// epilogue subtile loop
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
{
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
{
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD);
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
{
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
}
copy(tiled_r2s, tRS_rD, tRS_sD);
synchronize();
copy(tiled_s2r, tSR_sD, tSR_rD);
synchronize();
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
if (epilogue_op.is_source_needed())
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
}
}
}
private:
Params params;
};
namespace detail
{
template <class Element, class MaxVec>
constexpr auto get_vectorized_atomic_add_op()
{
using namespace cute;
auto constexpr MaxVecSize = size(MaxVec{});
if constexpr (is_same_v<Element, cutlass::half_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM70_RED_ADD_NOFTZ_F16x2{};
}
else
{
return SM70_RED_ADD_NOFTZ_F16{};
}
}
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM90_RED_ADD_NOFTZ_BF16x2{};
}
else
{
return SM90_RED_ADD_NOFTZ_BF16{};
}
}
else
{
// non-vectorized atomic add for all other types until supported
return TypedAtomicAdd<Element>{};
}
}
} // namespace detail
template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
struct EpilogueMoeFusedFinalizeBuilder
{
// assuming cooperative kernel schedule
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
using EpilogueTile = Shape<_128, EpiTileN>;
// Output of linear combination is ElementCompute instead of ElementD
// since we will be doing more compute on it; no need to cast yet.
using ThreadEpilogueOp
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
using SmemLayoutAtomD
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
using CopyAtomS2R = DefaultCopy;
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
template <class EpilogueOp>
struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
{
// We need to override this one using-declaration because otherwise we double up on the smem
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
CUTLASS_HOST_DEVICE
Sm90TmaWarpSpecializedAdapterWithSmemStorage(
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
: Base(params)
{
}
// These functions depend on the type of TensorMapStorage
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
{
}
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
{
}
};
using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
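// Illustrative compile-time check, not part of this diff: the CopyAtomR2G chosen by
// EpilogueMoeFusedFinalizeBuilder comes from detail::get_vectorized_atomic_add_op above, which
// picks the widest red.add instruction the epilogue tile width allows. The example_red_op_check
// namespace and the _8/_1 vector widths are hypothetical inputs used only for this sketch.
#include <type_traits>

namespace example_red_op_check
{
using namespace cute;
using cutlass::epilogue::collective::detail::get_vectorized_atomic_add_op;

// 8 or more half elements per access -> 4-wide f16x2 vectorized reduction (SM90).
static_assert(std::is_same_v<decltype(get_vectorized_atomic_add_op<cutlass::half_t, _8>()),
                  SM90_RED_ADD_NOFTZ_F16x2_V4>,
    "expected the 4-wide f16x2 reduction");

// A single bf16 element per access -> scalar bf16 reduction.
static_assert(std::is_same_v<decltype(get_vectorized_atomic_add_op<cutlass::bfloat16_t, _1>()),
                  SM90_RED_ADD_NOFTZ_BF16>,
    "expected the scalar bf16 reduction");
} // namespace example_red_op_check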
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace thread
{
/////////////////////////////////////////////////////////////////////////////////////////////////
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
float const exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct GELU_taylor<float>
{
static bool const kIsHeavy = true;
CUTLASS_DEVICE
float operator()(float const& z) const
{
float k0 = float(0.7978845608028654);
float k1 = float(0.044715);
return float(cutlass::constants::half<float>() * z
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_DEVICE
float operator()(float const& scalar, Params const& params_) const
{
return this->operator()(scalar);
}
};
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
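// Illustrative sketch, not part of this diff: the specialization above evaluates the tanh
// approximation of GELU, 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z^2))) with k0 ~= 0.7978845608 and
// k1 = 0.044715. geluExample() is a hypothetical device helper showing how the functor is invoked.
__device__ float geluExample(float z)
{
    cutlass::epilogue::thread::GELU_taylor<float> gelu;
    return gelu(z); // e.g. gelu(0.0f) == 0.0f and gelu(z) ~= z for large positive z
}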
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
#include "tensorrt_llm/common/quantization.h"
namespace tk = tensorrt_llm::common;
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol
{
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
/// Argument structure
struct Arguments
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments()
: batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_)
, batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
int64_t batch_stride_C_, int64_t batch_stride_D_)
: elementwise(elementwise_)
, batch_stride_alpha(batch_stride_alpha_)
, batch_stride_C(batch_stride_C_)
, batch_stride_D(batch_stride_D_)
{
}
};
struct Params
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args)
: elementwise(args.elementwise)
, batch_stride_alpha(args.batch_stride_alpha)
, batch_stride_C(args.batch_stride_C)
, batch_stride_D(args.batch_stride_D)
{
}
};
/// Shared storage
struct SharedStorage
{
};
private:
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
bool const per_token_quant_;
bool const per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params)
, shared_storage_(shared_storage)
, extent_(problem_size)
, elementwise_(params.elementwise)
, per_token_quant_(quant_option.hasPerTokenScaling())
, per_channel_quant_(quant_option.hasPerChannelScaling())
, ptr_alpha_row_(ptr_alpha_row)
, ptr_alpha_col_(ptr_alpha_col)
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
, extent_real_(problem_size_real)
{
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator())
{
iterator_C_.clear_mask();
}
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
{
element_alpha_col_ = *ptr_alpha_col_;
}
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
{
element_alpha_row_ = *ptr_alpha_row_;
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) ///< Total number of split-K slices
{
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx)
{
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue()
{
if (per_channel_quant_)
{
iterator_alpha_col_.load(fragment_alpha_col_);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx)
{
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
{
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx)
{
// Load alpha_row in begin_row only when per-token (row) scaling is used
if (per_token_quant_)
{
int thread_offset_row
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
}
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
{
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_)
{
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
}
else
{
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx)
{
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
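/// Applies both the per-channel (column) and per-token (row) quantization scales to an accumulator
/// fragment: result[i] = accum[i] * scale_col[i] * scale_row.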
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
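/// Applies a single column scale together with the per-token (row) scale, used when per-channel
/// quantization is disabled: result[i] = accum[i] * scale_col * scale_row.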
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
namespace detail
{
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
ThreadMap>
{
using WarpTileIterator
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
using SharedLoadIterator
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
static int const kFragmentsPerIteration = 2;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
{
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = int32_t;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
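/// Number of LoadType vectors needed to cover one AccessType access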
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Pointers into shared memory, one per swizzled load phase, in units of LoadType
LoadType const* pointers_[kLoadsPerAccess];
/// Stride along adjacent rows in units of LoadType
int stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements))
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointers
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
col_idx += (bank_offset + i) % kLoadsPerAccess;
pointers_[i] += thread_offset.row() * stride_ + col_idx;
}
}
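/// A sketch of the swizzle above, assuming ThreadMap::kElementsPerAccess == 8 (so AccessType is
/// 32 B, LoadType is 16 B, and kLoadsPerAccess == 2): threads whose starting column lies in an odd
/// 128 B segment get bank_offset == 1, which swaps the order of pointers_[0] and pointers_[1]
/// relative to an even segment, spreading the paired 16 B loads across shared memory banks.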
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] += pointer_offset / LoadType::kElements;
}
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i]
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
{
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
{
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
{
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
{
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
+ pointer_offset / LoadType::kElements;
int frag_row_idx
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
{
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v)
{
int vector_idx
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
}
}
}
}
}
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const
{
load_with_pointer_offset(frag, 0);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file epilogue_helpers.h
*
* This file defines the tag types for the epilogues. The empty structs exist so we can signal to
* template code which epilogue we want to run, and let the underlying code specify details such as
* element types, accumulator type, and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#include <cutlass/epilogue/fusion/operations.hpp>
namespace tensorrt_llm
{
namespace cutlass_extensions
{
struct EpilogueOpBiasSilu
{
};
struct EpilogueOpBiasReLU
{
};
struct EpilogueOpBiasFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpDefaultSilu
{
};
struct EpilogueOpDefaultReLU
{
};
struct EpilogueOpDefaultFtGelu
{
};
struct EpilogueOpDefault
{
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue
{
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, BiasScaleMode>;
};
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};
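// Illustrative use of the tag dispatch above (a sketch; the element and accumulator types below are
// arbitrary examples, not a configuration mandated by this file):
//
//   using SiluEpilogueOp = tensorrt_llm::cutlass_extensions::Epilogue<cutlass::half_t, 8, float,
//       tensorrt_llm::cutlass_extensions::EpilogueOpBiasSilu>::Op;
//
// resolves to cutlass::epilogue::thread::LinearCombinationSilu<cutlass::half_t, 8, float, float,
// BiasScaleMode>, while an unrecognized tag trips the static_assert in the primary template.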
} // namespace cutlass_extensions
} // namespace tensorrt_llm
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders are supported only with CUDA 12.0 and newer
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
{
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int stage_bytes = [&]() -> int
{
if constexpr (SwapAB)
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
}
else
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
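// Rough worked example (a sketch, not a guaranteed configuration): with 16-bit A/B,
// TileShapeMNK = 128x128x64 and SwapAB == false, one stage holds 16 KiB of A, 32 KiB for the
// doubled (gated) B operand, and 32 B of barriers, about 48 KiB in total, so an SM90 shared memory
// budget of roughly 227 KiB yields four stages with a small carveout (fewer as the carveout grows).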
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) && not detail::
is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
template <class /* ElementCompute */> class Activation, bool SwapAB = false>
struct CollectiveMmaGated
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
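// Gated mainloop: alongside A and B this collective streams a third "gate" operand (stored as an
// extra A- or B-shaped tensor, depending on SwapAB) and produces two accumulators, accum0 from
// A*B and accum1 from the gate operand.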
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
using InternalElementAux = cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
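// The gate operand is assumed to be stored contiguously after the main operand: its base pointer
// is ptr_A (when SwapAB) or ptr_B advanced by the full M*K*L / N*K*L element count.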
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<InternalElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
else
{
auto ptr_Aux = reinterpret_cast<InternalElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
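// Bytes delivered by TMA into a single pipeline stage: one A tile, one B tile, and one gate (aux) tile.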
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the kernel layer is that the
/// returned tuple must contain at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective; here the third element is
/// gAux_xkl - The tma tensor for the gate operand (A or B depending on SwapAB) after a local tile,
///            with shape (BLK_M/BLK_N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
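// Only a single elected lane per warp issues the TMA loads below.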
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
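// The gate (aux) operand takes A's role when SwapAB is set and B's role otherwise, so it is
// partitioned and turned into fragments with the matching thread_mma helpers.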
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
else
{
auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
}
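// Note (illustrative, not part of the public Arguments contract): the gated auxiliary operand has no
// explicit pointer; to_underlying_arguments() above assumes it is packed directly behind the primary
// operand. For example, with SwapAB == false and (N, K, L) = (4096, 1024, 1), ptr_Aux is taken to be
// args.ptr_B + 4096 * 1024, i.e. the second weight matrix of the gated pair stored contiguously after
// the first, with the same stride dB.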
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA
* instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
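// Worked example (hypothetical configuration): for a 128x128x128 tile with 8-bit FP8 A, B and Aux,
// each pipeline stage moves 128*128*8/8 = 16384 bytes per operand, so TmaTransactionBytes above
// evaluates to 3 * 16384 = 49152 bytes per stage.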
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
/// returned tuple contains at least two elements, with the first two elements being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_M/BLK_N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in a Cluster.
 * Waits for all stages to either be released (all
 * Consumer UNLOCKs), or, if a stage was never used,
 * to be acquired directly, since its phase is still
 * inverted from make_producer_start_state.
 */
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
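// GmmaFP8Accumulation wraps accum0/accum1 with a temporary accumulator fragment: accumulation0()
// and accumulation1() below are what the GMMAs write into, and promote_if_needed() folds that
// temporary into the real accumulator once the configured mma_promotion_interval of MMAs has been
// issued, limiting error growth from the reduced-precision FP8 tensor-core accumulation.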
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation0.promote_residue_if_needed();
accumulation1.promote_residue_if_needed();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
// #include <limits>
#include <algorithm>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
that feature at the moment.
*/
template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
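// Worked example (hypothetical fp16 problem): with 16-bit ElementA/ElementB, kAlignK =
// max(128/16, 128/16, 1) = 8. For problem_size.k() = 4096 and batch_count = 3 split-K slices,
// ceil_div(4096, 3) = 1366 rounds up to gemm_k_size = 1368, and grid_tiled_shape.k() =
// ceil_div(4096, 1368) = 3, so each slice covers at most 1368 of the 4096 K iterations.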
}
public:
/// Constructs the GEMM.
GemmUniversalBaseCompat() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
{
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Split-K parallel always requires a temporary workspace
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
}
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
{
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10))
{
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess)
{
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
}
else
{
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
if (smem_capacity < 0)
{
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess)
{
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes)
{
if (!workspace)
{
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm)
{
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
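// A minimal host-side usage sketch (illustrative only; `MyGemmKernel`, `args`, and `stream` are
// hypothetical and would normally come from kernel::DefaultGemmUniversal and the caller):
//
//   using Gemm = GemmUniversalBaseCompat<MyGemmKernel>;
//   typename Gemm::Arguments args = /* mode, problem_size, batch_count, pointers, strides, ... */;
//   if (Gemm::can_implement(args) == Status::kSuccess)
//   {
//       void* workspace = nullptr;
//       size_t workspace_bytes = Gemm::get_workspace_size(args);
//       if (workspace_bytes) { cudaMalloc(&workspace, workspace_bytes); }
//       Gemm gemm;
//       if (gemm.initialize(args, workspace, stream) == Status::kSuccess) { gemm.run(stream); }
//   }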
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include <algorithm>
#include <cstring>
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
int64_t* splitk_buffer_offsets)
{
// in_tensor: [problem_idx, k_partition, hidden_size]
// Note that different problems in in_tensor may have different hidden_size (= m*n),
// so we need splitk_buffer_offsets to locate each problem's partial results.
// out_tensor: problem_idx * [hidden_size]
int const problem_idx = blockIdx.y;
GemmCoord problem = problem_sizes[problem_idx];
int const hidden_size = problem.m() * problem.n();
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
T_OUT* out_tensor_ = out_tensor[problem_idx];
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
{
float sum = 0.0f;
for (int k_idx = 0; k_idx < splitk; k_idx++)
{
sum += (float) in_tensor_[k_idx * hidden_size + i];
}
out_tensor_[i] = (T_OUT) (sum);
}
}
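// Indexing example (hypothetical sizes): with splitk = 4 and two problems whose hidden sizes (m*n) are
// 2048 and 512, splitk_buffer_offsets would be {0, 2048}; problem 1 therefore starts reading its partial
// results at in_tensor + 2048 * 4 and, for each of its 512 outputs, sums 4 partials spaced hidden_size
// elements apart before writing the reduced value to out_tensor[1].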
/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params gemm_params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
{
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
{
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
{
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, std::vector<size_t> const& indices)
{
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size());
for (size_t i = 0; i < indices.size(); ++i)
{
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
public:
/// Constructs the GEMM.
BaseSplitkGrouped() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
{
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const& args)
{
if (args.host_problem_sizes == nullptr)
{
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
size_t total_mn = 0;
for (int i = 0; i < args.problem_count; i++)
{
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
}
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
args.host_problem_sizes, args.problem_count, args.threadblock_count);
}
return workSpaceSize;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each pointer passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of problem-K dimension.
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
{
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
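// Usage sketch (illustrative; `MyKernel` is hypothetical and the arrays are the host-side per-problem
// metadata that is later copied into Arguments):
//
//   BaseSplitkGrouped<MyKernel>::sort_problems(problem_count, problem_sizes.data(), lda.data(),
//       ldb.data(), ldc.data(), ldd.data(), offset_A.data(), offset_B.data(), offset_C.data(),
//       offset_D.data());
//
// All nine arrays are permuted consistently, so the per-problem pointers and strides still line up
// with their (now K-descending) problem sizes.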
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
{
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
return 0;
}
int multiprocessor_count;
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
return 0;
}
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
if (override_sm_count)
{
available_sm_count = multiprocessor_count;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0)
{
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0)
{
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return total_tiles
// unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count)
{
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating through
// problem sizes to determine that they have no work to do. This competes for cycles
// with those threadblocks that are assigned tiles to compute.
return std::min(total_tiles, occupancy_based_block_count);
}
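// Usage sketch (illustrative): `sufficient()` is typically used to choose the threadblock count before
// filling Arguments, e.g.
//
//   int threadblock_count = GemmGrouped::sufficient(host_problem_sizes.data(), problem_count);
//   if (threadblock_count == 0) { /* the occupancy or device queries failed */ }
//
// where `GemmGrouped` is an instantiation of this class and `host_problem_sizes` holds the
// cutlass::gemm::GemmCoord problem sizes on the host.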
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
}
else
{
gemm_params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_.update(args, workspace, tile_count);
}
else
{
gemm_params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
if (!gemm_params_.problem_visitor.problem_count)
{
return Status::kSuccess;
}
//
// Launch kernel
//
// Launch splitk grouped gemm
{
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Launch splitkReduction
{
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
dim3 block(256);
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
gemm_params_.splitk_buffer_offsets);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Initializes and runs the kernel.
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
using GemmKernel = GemmKernel_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
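// ---------------------------------------------------------------------------
// Illustrative usage sketch, not part of the original sources. `MyGroupedKernel`
// stands in for a grouped split-k kernel type assembled elsewhere, and the
// caller is assumed to have filled `args` (the Arguments alias inherited from
// BaseSplitkGrouped) with per-problem sizes, pointers, and split-k slices. The
// device class is driven through initialize()/run(), mirroring the operator()
// overloads above.
// ---------------------------------------------------------------------------
template <typename MyGroupedKernel>
cutlass::Status run_splitk_grouped_example(
    typename cutlass::gemm::device::SplitkGemmGrouped<MyGroupedKernel>::Arguments const& args, cudaStream_t stream)
{
    using GemmGrouped = cutlass::gemm::device::SplitkGemmGrouped<MyGroupedKernel>;
    GemmGrouped gemm_op;
    // The workspace holds precomputed tile metadata (when the problem visitor
    // requires it) plus the per-slice partial outputs reduced by splitkReduction.
    void* workspace = nullptr;
    size_t workspace_bytes = gemm_op.get_workspace_size(args);
    if (workspace_bytes != 0 && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess)
    {
        return cutlass::Status::kErrorInternal;
    }
    cutlass::Status status = gemm_op.initialize(args, workspace, stream);
    if (status == cutlass::Status::kSuccess)
    {
        status = gemm_op.run(stream); // grouped GEMM launch followed by splitkReduction
    }
    cudaFree(workspace);
    return status;
}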
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
static_assert(dependent_false<arch>, "Unrecognised parameterization");
};
template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
static constexpr int ThreadblockK = 8;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and activations will be cast to fp16;
// compute will happen in fp16, and results will then be converted back to bf16 for output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
// FP8 path: A/B are fp8, accumulation happens in fp32 (AccType), and the output type TypeC is set below.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
// Be careful: TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>.
using TypeC = __nv_bfloat16;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
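// ---------------------------------------------------------------------------
// Illustrative compile-time sketch, not part of the original sources. It
// instantiates two of the specializations above and checks a few derived
// constants. The fp16/uint8_t pairing assumes LayoutDetailsB (from
// mixed_gemm_B_layout.h, included above) is defined for that combination, as
// it is on the weight-only paths; the expected values follow directly from the
// definitions above (128-bit accesses over 16-bit elements, 16x8x16 MMA shape).
// ---------------------------------------------------------------------------
#include <cstdint>
namespace
{
// fp32 activations and weights fall back to the SIMT specialization.
using SimtTraits = cutlass::gemm::kernel::MixedGemmArchTraits<float, float, cutlass::arch::Sm80>;
static_assert(cutlass::platform::is_same<SimtTraits::OperatorClass, cutlass::arch::OpClassSimt>::value,
    "fp32 path uses SIMT");
static_assert(SimtTraits::ThreadblockK == 8 && SimtTraits::ElementsPerAccessA == 1, "scalar accesses");
// fp16 activations with 8-bit weights use the Ampere tensor-op specialization.
using Sm80HalfTraits = cutlass::gemm::kernel::MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;
static_assert(Sm80HalfTraits::ElementsPerAccessA == 8, "128-bit loads over 16-bit activations");
static_assert(cutlass::platform::is_same<Sm80HalfTraits::InstructionShape, cutlass::gemm::GemmShape<16, 8, 16>>::value,
    "16x8x16 HMMA instruction shape");
} // namespace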
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename arch>
struct Int8GemmArchTraits
{
using OperatorClass = cutlass::arch::OpClassSimt;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};
// ======================= Turing Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};
// ======================= Ampere Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
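// ---------------------------------------------------------------------------
// Illustrative compile-time sketch, not part of the original sources: it
// checks how the Int8GemmArchTraits specializations above pick the operator
// class and MMA instruction shape per architecture. Expected values are read
// directly from the definitions above; pre-Turing architectures fall back to
// the SIMT primary template.
// ---------------------------------------------------------------------------
#include <type_traits>
namespace
{
using Sm70Int8 = cutlass::gemm::kernel::Int8GemmArchTraits<cutlass::arch::Sm70>;
using Sm80Int8 = cutlass::gemm::kernel::Int8GemmArchTraits<cutlass::arch::Sm80>;
static_assert(std::is_same<Sm70Int8::OperatorClass, cutlass::arch::OpClassSimt>::value,
    "Volta has no int8 tensor cores, so the SIMT fallback applies");
static_assert(std::is_same<Sm80Int8::InstructionShape, cutlass::gemm::GemmShape<16, 8, 32>>::value,
    "Ampere int8 IMMA instruction shape");
} // namespace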