Unverified commit 222ce6f1 authored by Yineng Zhang, committed by GitHub

add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)


Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
parent 468d23cf
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvtx3/nvtx3.hpp>
#include <array>
namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#ifndef NVTX_DISABLE
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();
static thread_local std::size_t colorId = 0;
auto const color = kColors[colorId];
colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
return color;
#else
return nvtx3::color{0};
#endif
}
} // namespace tensorrt_llm::common::nvtx
#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \
::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name)
#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range)
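// Usage sketch (illustrative, not part of the upstream header): NVTX3_SCOPED_RANGE declares a
// local nvtx3::scoped_range whose lifetime spans the enclosing scope, so the profiler range is
// closed automatically on return. The function below is a hypothetical example.
inline void exampleProfiledStep()
{
    NVTX3_SCOPED_RANGE(example_step); // expands to a local variable named example_step_range
    // ... work measured by the profiler range ...
}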
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "cuda.h"
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <functional>
#include <mutex>
#include <thread>
#ifdef _MSC_VER
#define FN_NAME __FUNCTION__
#else
#define FN_NAME __func__
#endif
#if ENABLE_MULTI_DEVICE
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32},
{nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}};
return &dtypeMap;
}
namespace
{
// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin())
{
NCCLCHECK(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
COMM_SESSION.sendValue(id, *it, 0);
}
}
else
{
COMM_SESSION.recvValue(id, *group.begin(), 0);
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return id;
}
} // namespace
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group)
{
if (index != 0)
{
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end())
{
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}
TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group)
{
if (rank == currentRank)
break;
++groupRank;
}
TLLM_CHECK(groupRank < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
[](ncclComm_t* comm)
{
ncclCommDestroy(*comm);
delete comm;
});
NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
void const* tensorrt_llm::common::getCommSessionHandle()
{
#if ENABLE_MULTI_DEVICE
return &COMM_SESSION;
#else
return nullptr;
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
// Get the current CUDA context; a default context is created if there is none.
inline CUcontext getCurrentCudaCtx()
{
CUcontext ctx{};
CUresult err = cuCtxGetCurrent(&ctx);
if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
{
TLLM_CUDA_CHECK(cudaFree(nullptr));
err = cuCtxGetCurrent(&ctx);
}
TLLM_CHECK(err == CUDA_SUCCESS);
return ctx;
}
// Helper to create a per-CUDA-context singleton managed by std::shared_ptr.
// Unlike a conventional singleton, a singleton created with this helper is released as soon as it
// is no longer needed, instead of at process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
// Having the creator return std::unique_ptr is deliberate: it keeps the memory for T separate from
// the shared_ptr control block, so the T memory can be freed even while an observer weak_ptr is
// still held in mObservers.
// The creator itself must not own CUDA resources; only the object it creates may.
PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
CUcontext ctx{getCurrentCudaCtx()};
std::shared_ptr<T> result = mObservers[ctx].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, ctx](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
// Clear the observer to keep mObservers from growing when users create and destroy CUDA
// contexts frequently.
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
std::lock_guard<std::mutex> lk{mMutex};
// Must check the observer again because another thread may have created a new instance for this
// ctx just before we locked mMutex. We can't infer that the observer is stale from the fact that
// obj is destroyed, because shared_ptr ref-count checking and observer removal are not one atomic
// operation, and the observer may have been changed to observe another instance.
observedObjHolder = mObservers.at(ctx).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(ctx);
}
}};
mObservers.at(ctx) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-context.
std::unordered_map<CUcontext, std::weak_ptr<T>> mObservers;
};
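// Usage sketch (illustrative, not part of the upstream file): "ExampleCtxResource" is a
// hypothetical placeholder showing how a per-CUDA-context singleton would be obtained. The
// creator must not own CUDA resources itself; only the object it creates may.
struct ExampleCtxResource
{
};

std::shared_ptr<ExampleCtxResource> getExampleCtxResource()
{
    static PerCudaCtxSingletonCreator<ExampleCtxResource> creator(
        []() { return std::make_unique<ExampleCtxResource>(); },
        [](ExampleCtxResource* obj) { delete obj; });
    return creator(); // one shared instance per current CUDA context, released when no longer used
}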
template <typename T>
class PerThreadSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
// Having the creator return std::unique_ptr is deliberate: it keeps the memory for T separate from
// the shared_ptr control block, so the T memory can be freed even while an observer weak_ptr is
// still held in mObservers.
// The creator itself must not own CUDA resources; only the object it creates may.
PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
std::thread::id thread = std::this_thread::get_id();
std::shared_ptr<T> result = mObservers[thread].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, thread](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
// Clear the observer to keep mObservers from growing when threads are created and destroyed
// frequently.
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
std::lock_guard<std::mutex> lk{mMutex};
// Must check the observer again because a new instance may have been created for this thread id
// just before we locked mMutex. We can't infer that the observer is stale from the fact that obj
// is destroyed, because shared_ptr ref-count checking and observer removal are not one atomic
// operation, and the observer may have been changed to observe another instance.
observedObjHolder = mObservers.at(thread).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(thread);
}
}};
mObservers.at(thread) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-thread.
std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
};
} // namespace
std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
TLLM_CUDA_CHECK(cublasCreate(handle.get()));
return handle;
},
[](cublasHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
static PerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
return handle;
},
[](cublasLtHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
{
static PerThreadSingletonCreator<tensorrt_llm::common::CublasMMWrapper> creator(
[cublasHandle, cublasltHandle, stream, workspace]() -> auto
{
auto wrapper = std::unique_ptr<tensorrt_llm::common::CublasMMWrapper>(
new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace));
return wrapper;
},
[](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; });
return creator();
}
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/workspace.h"
#include <NvInferRuntime.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <cstring>
#include <map>
#include <memory>
#include <nvml.h>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
namespace tensorrt_llm::common
{
// Write values into buffer
template <typename T>
void write(char*& buffer, T const& val)
{
std::memcpy(buffer, &val, sizeof(T));
buffer += sizeof(T);
}
// Read values from buffer
template <typename T>
void read(char const*& buffer, T& val)
{
std::memcpy(&val, buffer, sizeof(T));
buffer += sizeof(T);
}
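// Usage sketch (illustrative, not part of the upstream header): write/read advance the buffer
// pointer by sizeof(T), so serialization and deserialization must visit the same fields in the
// same order. The function below is a hypothetical example.
inline void exampleSerializeRoundTrip()
{
    char buffer[sizeof(int) + sizeof(float)];
    char* w = buffer;
    write(w, 42);   // serialize an int
    write(w, 1.5f); // then a float

    char const* r = buffer;
    int i{};
    float f{};
    read(r, i); // restores 42
    read(r, f); // restores 1.5f
}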
// Like std::unique_ptr, but does not prevent generation of the default copy constructor when used as a class member.
// The copy constructor produces a nullptr, so the plugin's default copy constructor will not really copy it, and
// your clone() implementation is responsible for re-initializing such data members.
// This simplifies clone() implementations when there are many data members, including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
using std::unique_ptr<T, Del>::unique_ptr;
// for compatibility with std::make_unique
explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
: std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
{
}
// copy constructor produces nullptr
UniqPtrWNullCopy(UniqPtrWNullCopy const&)
: std::unique_ptr<T, Del>::unique_ptr{}
{
}
};
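// Usage sketch (illustrative, not part of the upstream header): a plugin-like class can keep the
// compiler-generated copy constructor even though it holds a unique_ptr-like member; the copy
// starts with a null pointer and clone() re-creates the owned state. "ExamplePlugin" and its
// members are hypothetical.
struct ExamplePlugin
{
    int mParam{0};
    UniqPtrWNullCopy<int> mScratch{nullptr}; // copies of ExamplePlugin start with mScratch == nullptr

    ExamplePlugin* clone() const
    {
        auto* copy = new ExamplePlugin(*this); // default copy: mParam copied, mScratch left null
        copy->mScratch.reset(new int(0));      // clone() is responsible for re-initializing it
        return copy;
    }
};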
// for testing only
void const* getCommSessionHandle();
} // namespace tensorrt_llm::common
inline bool isBuilding()
{
auto constexpr key = "IS_BUILDING";
auto const val = getenv(key);
return val != nullptr && std::string(val) == "1";
}
#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) \
{ \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
#endif // ENABLE_MULTI_DEVICE
//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally.
//! Get cublas and cublasLt handle for current cuda context
std::shared_ptr<cublasHandle_t> getCublasHandle();
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace);
#ifndef DEBUG
#define PLUGIN_CHECK(status) \
do \
{ \
if (status != 0) \
abort(); \
} while (0)
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
return STATUS_BAD_PARAM; \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
return STATUS_FAILURE; \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
return err; \
} \
} while (0)
#define DEBUG_PRINTF(...) \
do \
{ \
} while (0)
#else
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_BAD_PARAM; \
} \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_FAILURE; \
} \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \
return err; \
} \
} while (0)
#define PLUGIN_CHECK(status) \
{ \
if (status != 0) \
{ \
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \
abort(); \
} \
}
#define DEBUG_PRINTF(...) \
do \
{ \
printf(__VA_ARGS__); \
} while (0)
#endif // DEBUG
#define NVML_CHECK(cmd) \
do \
{ \
nvmlReturn_t r = cmd; \
if (r != NVML_SUCCESS) \
{ \
printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
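// Usage sketch (illustrative, not part of the upstream header): wrap NVML calls so that a failure
// prints the file, line and error string and aborts the process.
inline void exampleQueryNvmlDevice()
{
    NVML_CHECK(nvmlInit());
    nvmlDevice_t device;
    NVML_CHECK(nvmlDeviceGetHandleByIndex(0, &device));
    NVML_CHECK(nvmlShutdown());
}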
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <float.h>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
struct QuantTypeStaticVals;
template <>
struct QuantTypeStaticVals<int8_t>
{
static constexpr float MAX_VAL = 127.f;
static constexpr float MIN_SCALING_FACTOR = 0.f;
static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX;
};
#ifdef ENABLE_FP8
template <>
struct QuantTypeStaticVals<__nv_fp8_e4m3>
{
static constexpr float MAX_VAL = 448.f;
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720
static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f);
static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f);
};
#endif // ENABLE_FP8
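// Usage sketch (illustrative, not part of the upstream header): the per-type constants bound the
// scaling factor derived from an observed absolute maximum. "exampleClampedScale" is hypothetical.
template <typename T>
inline float exampleClampedScale(float absMax)
{
    using Vals = QuantTypeStaticVals<T>;
    float const scale = absMax / Vals::MAX_VAL; // map the observed range onto the type's max value
    return scale > Vals::MIN_SCALING_FACTOR ? scale : Vals::MIN_SCALING_FACTOR;
}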
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <float.h>
#include <type_traits>
namespace cg = cooperative_groups;
namespace tensorrt_llm
{
namespace common
{
template <int VPT>
struct BytesToType;
template <>
struct BytesToType<1>
{
using type = uint8_t;
};
template <>
struct BytesToType<2>
{
using type = uint16_t;
};
template <>
struct BytesToType<4>
{
using type = uint32_t;
};
template <>
struct BytesToType<8>
{
using type = uint64_t;
};
template <>
struct BytesToType<16>
{
using type = float4;
};
template <int Bytes>
__device__ inline void copy(void const* local, void* data)
{
using T = typename BytesToType<Bytes>::type;
T const* in = static_cast<T const*>(local);
T* out = static_cast<T*>(data);
*out = *in;
}
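// Usage sketch (illustrative, not part of the upstream header): copy<16> moves 16 bytes through a
// single float4 register, the usual vectorized load/store trick (pointers must be 16-byte aligned).
__device__ inline void exampleVectorizedCopy(float const* src, float* dst)
{
    copy<16>(src, dst); // one 128-bit access instead of four 32-bit ones
}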
static float constexpr HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32 (rather than a shift) so the guard also works when blockDim.x is not a
// multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f);
val = warpReduceSum<T>(val);
return val;
}
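// Usage sketch (illustrative, not part of the upstream header): a hypothetical kernel in which
// each block sums one row of `in` (one element per thread) into `out[blockIdx.x]`.
template <typename T>
__global__ void exampleRowSumKernel(T const* in, T* out, int cols)
{
    T val = threadIdx.x < cols ? in[blockIdx.x * cols + threadIdx.x] : (T) 0.0f;
    val = blockReduceSum<T>(val);
    if (threadIdx.x == 0)
    {
        out[blockIdx.x] = val;
    }
}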
template <typename T>
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get max in each warp
if (lane == 0) // record in-warp max by warp idx
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32 (rather than a shift) so the guard also works when blockDim.x is not a
// multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockAllReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get max in each warp
if (lane == 0) // record in-warp max by warp idx
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32 (rather than a shift) so the guard also works when blockDim.x is not a
// multiple of 32.
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T) (0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val)
{
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T) 0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T) (0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val)
{
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get max in each warp
if (lane == 0) // record in-warp max by warp idx
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[wid][i] = val[i];
}
}
__syncthreads();
// Use blockDim.x / 32 (rather than a shift) so the guard also works when blockDim.x is not a
// multiple of 32.
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[lane][i] : (T) -1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T) 0.0f;
}
template <int NUM>
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
{
cg::thread_block cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
int const tid = cta.thread_rank();
int const blockz = blockDim.x;
for (int i = 0; i < NUM; i++)
{
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
#else
// TODO Add implementation here
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
assert(false);
}
#endif
}
cg::sync(cta);
if (tid == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
float beta = 0.0f;
for (int j = 0; j < blockz; j += 32)
{
beta += cgBlockReduceSumElements_shm[i * blockz + j];
}
element_list[i] = beta;
}
}
}
template <typename T, int MAX_K>
struct TopK
{
int p[MAX_K]; // indices; -1 at the tail if the array is not full
T u[MAX_K]; // values in descending order; -MAX_T_VAL if the element is invalid
__device__ __forceinline__ void insert(T const elem, int const elem_id)
{
if (elem_id < 0)
{
return;
}
// Conditions for updating the array:
// 1. array is not full
// 2. elem is greater than the smallest (last) element in the array
// 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller
bool const need_update
= (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]);
if (!need_update)
{
return;
}
// Find suitable index for the new element
int i;
for (i = MAX_K - 2; i >= 0; --i)
{
bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]);
if (!need_decrease)
break;
}
// Move elements to correct positions
for (int k = MAX_K - 2; k >= i; --k)
{
p[k + 1] = p[k];
u[k + 1] = u[k];
}
p[i] = elem_id;
u[i] = elem;
}
__device__ __forceinline__ void init()
{
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
for (int i = 0; i < MAX_K; i++)
{
p[i] = -1;
u[i] = -MAX_T_VAL;
}
}
};
template <typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
{
TopK<T, MAX_K> res = a;
for (int i = 0; i < MAX_K; ++i)
res.insert(b.u[i], b.p[i]);
return res;
}
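// Usage sketch (illustrative, not part of the upstream header): each thread scans a strided slice
// of `logits` and keeps its local top-MAX_K candidates; thread-local partials can then be merged
// pairwise with reduce_topk_op (e.g. in a shared-memory or cub-based block reduction, not shown).
template <typename T, int MAX_K>
__device__ inline TopK<T, MAX_K> exampleThreadLocalTopK(T const* logits, int n)
{
    TopK<T, MAX_K> partial;
    partial.init();
    for (int i = threadIdx.x; i < n; i += blockDim.x)
    {
        partial.insert(logits[i], i);
    }
    return partial;
}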
template <typename T>
struct TopK_2
{
int p = -1;
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
__device__ __forceinline__ void insert(T elem, int elem_id)
{
if (elem > u)
{
u = elem;
p = elem_id;
}
}
__device__ __forceinline__ void init()
{
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
p = -1;
}
};
template <typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
{
return a.u > b.u ? a : b;
}
template <typename T>
__device__ __forceinline__ T clamp_inf_for_half(float const input)
{
return input;
}
template <>
__device__ __forceinline__ half clamp_inf_for_half(float const input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
}
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <functional>
#include <numeric>
#include <optional>
#include <sstream>
namespace tensorrt_llm::common::stl_utils
{
template <typename TInputIt, typename TOutputIt, typename TBinOp>
constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op)
{
if (first != last)
{
auto val = *first;
while (true)
{
*dFirst = val;
++dFirst;
++first;
if (first == last)
{
break;
}
val = op(std::move(val), *first);
}
}
return dFirst;
}
template <typename TInputIt, typename TOutputIt>
constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst)
{
#if defined(__GNUC__) && __GNUC__ <= 8
return basicInclusiveScan(first, last, dFirst, std::plus<>{});
#else
return std::inclusive_scan(first, last, dFirst);
#endif
}
template <typename TInputIt, typename TOutputIt, typename T, typename TBinOp>
constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op)
{
if (first != last)
{
while (true)
{
T tmp{op(init, *first)};
*dFirst = init;
++dFirst;
++first;
if (first == last)
{
break;
}
init = std::move(tmp);
}
}
return dFirst;
}
template <typename TInputIt, typename TOutputIt, typename T>
constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init)
{
#if defined(__GNUC__) && __GNUC__ <= 8
return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{});
#else
return std::exclusive_scan(first, last, dFirst, std::move(init));
#endif
}
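// Usage sketch (illustrative, not part of the upstream header): the wrappers fall back to the
// hand-rolled scans on GCC 8 and older and defer to the standard library otherwise.
inline void exampleScans()
{
    int const sizes[3] = {3, 1, 4};
    int incl[3] = {0, 0, 0};
    int offsets[3] = {0, 0, 0};
    inclusiveScan(sizes, sizes + 3, incl);       // incl    == {3, 4, 8}
    exclusiveScan(sizes, sizes + 3, offsets, 0); // offsets == {0, 3, 4}
}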
template <typename T, typename = void>
struct HasOperatorOutput : std::false_type
{
};
template <typename T>
struct HasOperatorOutput<T, std::void_t<decltype((std::declval<std::ostream&>() << std::declval<T>()))>>
: std::true_type
{
};
template <typename T>
std::string toString(T const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
{
std::ostringstream oss;
oss << t;
return oss.str();
}
template <typename T>
std::string toString(std::optional<T> const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
{
std::ostringstream oss;
if (t)
{
oss << t.value();
}
else
{
oss << "None";
}
return oss.str();
}
} // namespace tensorrt_llm::common::stl_utils
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/assert.h"
#include <cerrno>
#include <cstdarg>
#include <cstring>
#include <iostream>
#include <string>
namespace tensorrt_llm::common
{
namespace
{
std::string vformat(char const* fmt, va_list args)
{
va_list args0;
va_copy(args0, args);
auto const size = vsnprintf(nullptr, 0, fmt, args0);
if (size <= 0)
return "";
std::string stringBuf(size, char{});
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
return stringBuf;
}
} // namespace
std::string fmtstr(char const* format, ...)
{
va_list args;
va_start(args, format);
std::string result = vformat(format, args);
va_end(args);
return result;
}
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
{
std::unordered_set<std::string> values;
if (!input.empty())
{
std::stringstream valStream(input);
std::string val;
while (std::getline(valStream, val, delimiter))
{
if (!val.empty())
{
values.insert(val);
}
}
}
return values;
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <iomanip>
#include <sstream>
#include "tensorrt_llm/common/timestampUtils.h"
namespace tensorrt_llm::common
{
std::string getCurrentTimestamp()
{
auto now = std::chrono::system_clock::now();
auto now_t = std::chrono::system_clock::to_time_t(now);
auto tm = *std::localtime(&now_t);
auto epoch_to_now = now.time_since_epoch();
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
std::ostringstream stream;
stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
stream << "." << std::setfill('0') << std::setw(6) << us.count();
return stream.str();
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
namespace tensorrt_llm::common
{
/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS.uuuuuu"
std::string getCurrentTimestamp();
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdlib>
#if !defined(_MSC_VER)
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#endif
#include <sstream>
namespace tensorrt_llm::common
{
namespace
{
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
}
#if !defined(_MSC_VER)
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{""}
{
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
auto const trace = getTrace();
std::runtime_error::operator=(
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
}
#else
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
, mNbFrames{}
{
}
#endif
TllmException::~TllmException() noexcept = default;
std::string TllmException::getTrace() const
{
#if defined(_MSC_VER)
return "";
#else
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
std::ostringstream buf;
for (auto i = 1; i < mNbFrames; ++i)
{
Dl_info info;
if (dladdr(mCallstack[i], &info) && info.dli_sname)
{
auto const clearName = demangle(info.dli_sname);
buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
}
else
{
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
}
if (i < mNbFrames - 1)
buf << std::endl;
}
if (mNbFrames == MAX_FRAMES)
buf << std::endl << "[truncated]";
std::free(trace);
return buf.str();
#endif
}
std::string TllmException::demangle(char const* name)
{
#if defined(_MSC_VER)
return name;
#else
std::string clearName{name};
auto status = -1;
auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
if (status == 0)
{
clearName = demangled;
std::free(demangled);
}
return clearName;
#endif
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace tensorrt_llm::common
{
std::uintptr_t constexpr kCudaMemAlign = 128;
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
uintptr_t addr = (uintptr_t) ptr;
if (addr % to)
{
addr += to - addr % to;
}
return (int8_t*) addr;
}
constexpr size_t alignSize(size_t size, size_t to)
{
if ((size % to) != 0U)
{
size += to - size % to;
}
return size;
}
inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
{
uintptr_t addr = (uintptr_t) ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t*) addr, alignment);
}
inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}
inline int8_t* nextWorkspacePtr(
int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
{
uintptr_t curr_offset = offset;
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
offset = next_offset;
return newptr;
}
inline int8_t* nextWorkspacePtrWithAlignment(
int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
inline size_t calculateTotalWorkspaceSize(
size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)
{
total += workspaces[i];
if (workspaces[i] % alignment)
{
total += alignment - (workspaces[i] % alignment);
}
}
return total;
}
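// Usage sketch (illustrative, not part of the upstream header): carve two aligned sub-buffers out
// of a single workspace allocation. Sizes and names below are hypothetical.
inline void examplePartitionWorkspace(int8_t* workspace, int8_t*& bufA, int8_t*& bufB)
{
    size_t const sizes[2] = {1000, 4096};
    // calculateTotalWorkspaceSize(sizes, 2) gives the number of bytes to allocate for `workspace`.
    uintptr_t offset = 0;
    bufA = nextWorkspacePtr(workspace, offset, sizes[0]); // offset advances by the aligned size
    bufB = nextWorkspacePtr(workspace, offset, sizes[1]);
}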
} // namespace tensorrt_llm::common
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/numeric/numeric_types.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif
namespace cute
{
//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////
template <class T>
struct TypedAtomicAdd
{
using SRegisters = T[1];
using DRegisters = T[1];
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
{
atomicAdd(&dst, src);
}
};
template <class T>
struct Copy_Traits<TypedAtomicAdd<T>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////
struct SM70_RED_ADD_NOFTZ_F16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM70_RED_ADD_NOFTZ_F16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
} // end namespace cute
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace arch
{
// Tag identifying the MMA path that dequantizes interleaved B to A's type before the multiply-add.
struct OpMultiplyAddDequantizeInterleavedBToA;
/*
Below we have extra tags to signal what kind of dequantization we want to do
(per-column scale, fine-grained scale only, fine-grained scale with zeros). This still
lets us reuse the existing template infrastructure (incl. that in CUTLASS). However, we
split the tag back into OpMultiplyAddDequantizeInterleavedBToA along with the
quantization op before instantiating the GEMM pieces.
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
// The default just forwards the original operator
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
using TaggedOperator = MmaOp;
};
// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};
// Here we instantiate some structs to "detag" the tagged operator, splitting it back into the
// original operator plus the extra information. If no extra information was tagged, the
// dequantization op defaults to per-column scaling.
template <typename TaggedMmaOp>
struct DetagOperator
{
using Operator = TaggedMmaOp;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};
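// Illustrative check (not part of the upstream header): tagging an MMA operator with a quantization
// op and detagging it recovers the quantization mode.
static_assert(DetagOperator<TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                  WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator>::QuantOp
        == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY,
    "TagOperator/DetagOperator must round-trip the quantization mode");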
} // namespace arch
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace cutlass_extensions
{
template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr;
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
}
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
// configuration.
return 0;
}
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
}
int max_active_blocks = -1;
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
}
return max_active_blocks;
}
} // namespace cutlass_extensions
} // namespace tensorrt_llm
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cute/numeric/numeric_types.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
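// Summary (interpretive): this epilogue fuses the MoE "finalize" step into the grouped GEMM.
// For each expert group it takes the accumulator tile, optionally adds the source C and a bias,
// multiplies by a per-row scale (e.g. the router weight), and scatters the rows back to their
// token positions in the final output via `scatter_index`, using an atomic-add copy atom so that
// contributions from different experts accumulate into the same output row.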
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
class CopyOpR2G>
class EpilogueMoeFusedFinalize
{
public:
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
using StrideD = StrideD_;
using InternalStrideD = cute::remove_pointer_t<StrideD>;
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
struct SharedStorage
{
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
};
struct TensorMapStorage
{
};
struct Arguments
{
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C{};
StrideC dC{};
ElementD* ptr_D{};
StrideD dD{};
ElementBias const* ptr_bias{};
StrideBias dBias{};
ElementScale const* ptr_scale{};
StrideScale dScale{};
int64_t const* group_offset{};
int32_t const* scatter_index{};
cutlass::FastDivmod num_rows_in_final_output;
};
using Params = Arguments;
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
{
return args;
}
template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
{
return 0;
}
template <class ProblemShape>
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
{
return cutlass::Status::kSuccess;
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
{
bool implementable = true;
if (problem_shape.is_host_problem_shape_available())
{
// Check alignment for all problem sizes
for (int i = 0; i < problem_shape.groups(); i++)
{
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M, N, K, L] = problem_shape_MNKL;
implementable = implementable
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
}
}
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
"reduction instruction.\n");
}
return implementable;
}
CUTLASS_HOST_DEVICE
EpilogueMoeFusedFinalize(Params const& params_)
: params(params_)
{
}
CUTLASS_DEVICE
bool is_source_needed()
{
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
return params.ptr_C != nullptr
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
}
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
class TiledMma, class ResidueMNK>
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
auto synchronize = [&]()
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Batches are managed by using appropriate pointers to C and D matrices
int32_t const mock_L = 1;
int32_t const mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
// If scalar alpha/beta are provided, the same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, alpha/beta can differ between batches/groups;
// in that case we fetch the correct alpha/beta for the current batch/group using the group index.
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
Tensor sD = as_position_independent_swizzle_tensor(sD_);
// Function to scatter output rows
auto& num_rows = params.num_rows_in_final_output;
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
auto get_scatter_idx = [&](auto i)
{
auto scatter = read_scatter_map(i);
int quot, rem;
num_rows(quot, rem, scatter);
return rem;
};
// Represent the full output tensor
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
Tensor mD_mnl = make_gather_tensor(
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
// Use a dummy batch extent of 1 for the bias; the batch dimension is not used here.
bool const is_bias_needed = params.ptr_bias != nullptr;
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
Tensor mScale_mnl = make_tensor(
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
Tensor gC_mnl
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gBias_mnl
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gScale_mnl
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Get the smallest tiled copy we can use to retile the accumulators
TiledCopy tiled_copy_C_atom
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
// Make a tiled copy vectorized along major direction of D
auto tiled_s2r = [&]()
{
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
}
else
{
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
}
}();
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
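// Staging note (interpretive): each subtile is first copied register->smem in the layout produced
// by the MMA (tiled_r2s), then re-read smem->register with a layout vectorized along the major
// dimension of D (tiled_s2r), so the final scatter/atomic-add to global memory is contiguous per
// thread.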
// epilogue subtile loop
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
{
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
{
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD);
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
{
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
}
copy(tiled_r2s, tRS_rD, tRS_sD);
synchronize();
copy(tiled_s2r, tSR_sD, tSR_rD);
synchronize();
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
if (epilogue_op.is_source_needed())
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
}
}
}
private:
Params params;
};
namespace detail
{
template <class Element, class MaxVec>
constexpr auto get_vectorized_atomic_add_op()
{
using namespace cute;
auto constexpr MaxVecSize = size(MaxVec{});
if constexpr (is_same_v<Element, cutlass::half_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM70_RED_ADD_NOFTZ_F16x2{};
}
else
{
return SM70_RED_ADD_NOFTZ_F16{};
}
}
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM90_RED_ADD_NOFTZ_BF16x2{};
}
else
{
return SM90_RED_ADD_NOFTZ_BF16{};
}
}
else
{
// non-vectorized atomic add for all other types until supported
return TypedAtomicAdd<Element>{};
}
}
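// Illustrative check (a sketch, assuming an 8-wide maximum vector expressed as cute::_8):
//   static_assert(cute::is_same_v<
//       decltype(get_vectorized_atomic_add_op<cutlass::half_t, cute::_8>()),
//       SM90_RED_ADD_NOFTZ_F16x2_V4>);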
} // namespace detail
template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
struct EpilogueMoeFusedFinalizeBuilder
{
// assuming cooperative kernel schedule
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
using EpilogueTile = Shape<_128, EpiTileN>;
// Output of linear combination is ElementCompute instead of ElementD
// since we will be doing more compute on it; there is no need to cast yet.
using ThreadEpilogueOp
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
using SmemLayoutAtomD
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
using CopyAtomS2R = DefaultCopy;
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
template <class EpilogueOp>
struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
{
// We need to override this using-declaration, because otherwise we would double up on the smem
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
CUTLASS_HOST_DEVICE
Sm90TmaWarpSpecializedAdapterWithSmemStorage(
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
: Base(params)
{
}
// These functions depend on the type of TensorMapStorage
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
{
}
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
{
}
};
using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace thread
{
/////////////////////////////////////////////////////////////////////////////////////////////////
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
float const exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct GELU_taylor<float>
{
static bool const kIsHeavy = true;
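// Evaluates the tanh approximation of GELU:
//   GELU(z) ~= 0.5 * z * (1 + tanh( sqrt(2/pi) * (z + 0.044715 * z^3) ))
// where k0 below is sqrt(2/pi) ~= 0.7978845608 and the tanh is computed with tanh_opt above.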
CUTLASS_DEVICE
float operator()(float const& z) const
{
float k0 = float(0.7978845608028654);
float k1 = float(0.044715);
return float(cutlass::constants::half<float>() * z
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_DEVICE
float operator()(float const& scalar, Params const& params_) const
{
return this->operator()(scalar);
}
};
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
#include "tensorrt_llm/common/quantization.h"
namespace tk = tensorrt_llm::common;
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol
{
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
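// Interpretive summary: each visited accumulator element is dequantized as
//   D[m][n] = convert<ElementOutput>( alpha_row[m] * alpha_col[n] * Acc[m][n] )
// where alpha_row collapses to a single scalar when per-token scaling is disabled, and alpha_col
// collapses to a single scalar when per-channel scaling is disabled.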
/// Argument structure
struct Arguments
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments()
: batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_)
, batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
int64_t batch_stride_C_, int64_t batch_stride_D_)
: elementwise(elementwise_)
, batch_stride_alpha(batch_stride_alpha_)
, batch_stride_C(batch_stride_C_)
, batch_stride_D(batch_stride_D_)
{
}
};
struct Params
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args)
: elementwise(args.elementwise)
, batch_stride_alpha(args.batch_stride_alpha)
, batch_stride_C(args.batch_stride_C)
, batch_stride_D(args.batch_stride_D)
{
}
};
/// Shared storage
struct SharedStorage
{
};
private:
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
bool const per_token_quant_;
bool const per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params)
, shared_storage_(shared_storage)
, extent_(problem_size)
, elementwise_(params.elementwise)
, per_token_quant_(quant_option.hasPerTokenScaling())
, per_channel_quant_(quant_option.hasPerChannelScaling())
, ptr_alpha_row_(ptr_alpha_row)
, ptr_alpha_col_(ptr_alpha_col)
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
, extent_real_(problem_size_real)
{
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator())
{
iterator_C_.clear_mask();
}
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
{
element_alpha_col_ = *ptr_alpha_col_;
}
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
{
element_alpha_row_ = *ptr_alpha_row_;
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices)                 ///< Total number of split-K slices
{
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx)
{
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue()
{
if (per_channel_quant_)
{
iterator_alpha_col_.load(fragment_alpha_col_);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx)
{
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
{
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx)
{
// load alpha_row in begin_row only when per-token (row) scaling is used
if (per_token_quant_)
{
int thread_offset_row
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
}
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
{
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_)
{
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
}
else
{
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx)
{
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
namespace detail
{
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
ThreadMap>
{
using WarpTileIterator
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
using SharedLoadIterator
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
static int const kFragmentsPerIteration = 2;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
{
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = int32_t;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Pointers into shared memory, in units of LoadType
LoadType const* pointers_[kLoadsPerAccess];
/// Stride along adjacent rows in units of LoadType
int stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements))
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointers
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
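// Interpretive note: the starting column of each pointer is rotated by a bank-dependent offset so
// that the kLoadsPerAccess loads issued per access hit different 128-byte regions of shared memory,
// matching the swizzled store pattern of the mixed warp tile iterator and avoiding bank conflicts.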
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
col_idx += (bank_offset + i) % kLoadsPerAccess;
pointers_[i] += thread_offset.row() * stride_ + col_idx;
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] += pointer_offset / LoadType::kElements;
}
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i]
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
{
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
{
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
{
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
{
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
+ pointer_offset / LoadType::kElements;
int frag_row_idx
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
{
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v)
{
int vector_idx
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
}
}
}
}
}
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const
{
load_with_pointer_offset(frag, 0);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file epilogue_helpers.h
*
* This file includes types for the epilogues. The empty structs exist so we can signal to template
* code the type of epilogue we want to run, and let the underlying code specify the details such as
* element types, accumulator type and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#include <cutlass/epilogue/fusion/operations.hpp>
namespace tensorrt_llm
{
namespace cutlass_extensions
{
struct EpilogueOpBiasSilu
{
};
struct EpilogueOpBiasReLU
{
};
struct EpilogueOpBiasFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpDefaultSilu
{
};
struct EpilogueOpDefaultReLU
{
};
struct EpilogueOpDefaultFtGelu
{
};
struct EpilogueOpDefault
{
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue
{
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, BiasScaleMode>;
};
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};
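// Usage sketch (illustrative; the element type and vector width below are arbitrary examples):
//   using FtGeluOp = typename tensorrt_llm::cutlass_extensions::Epilogue<
//       cutlass::half_t, /*ElementsPerVectorAccess=*/8, /*ElementAccumulator=*/float,
//       EpilogueOpBiasFtGelu>::Op;
// The resolved Op is the thread-level cutlass::epilogue::thread functor that kernel templates
// plug into their epilogue.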
} // namespace cutlass_extensions
} // namespace tensorrt_llm
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders should only be used starting with CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
{
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int stage_bytes = [&]() -> int
{
if constexpr (SwapAB)
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
}
else
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
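// Worked example (interpretive; assuming half-precision A/B, a 128x128x64 tile, SwapAB == false,
// and zero carveout): one stage holds
//   A: (16 * 128 * 64) / 8     = 16384 bytes
//   B: (16 * 128 * 64 * 2) / 8 = 32768 bytes  (doubled: the gated mainloop stages two B operands,
//                                              e.g. the up and gate projection weights)
//   barriers:                       32 bytes
// i.e. roughly 48 KiB per stage, so a ~227 KiB SM90 shared memory capacity yields 4 pipeline stages.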
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) && not detail::
is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////