Unverified Commit 3ee62235 authored by Yineng Zhang, committed by GitHub
revert the MoE dependence (#3230)

parent 9829e77e
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
namespace
{
bool initCheckDebug()
{
auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE";
auto const debugEnabled = std::getenv(kDebugEnabled);
return debugEnabled && debugEnabled[0] == '1';
}
} // namespace
bool DebugConfig::isCheckDebugEnabled()
{
static bool const debugEnabled = initCheckDebug();
return debugEnabled;
}
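// Usage sketch (illustrative, not part of the original file): debug-only checks are
// enabled by exporting TLLM_DEBUG_MODE=1 before the process starts; the result is
// cached in the function-local static above on first use, so later changes to the
// environment variable have no effect.
//
//   $ TLLM_DEBUG_MODE=1 ./my_app   # hypothetical binary
//
//   if (DebugConfig::isCheckDebugEnabled())
//   {
//       // run expensive sanity checks, e.g. via TLLM_CHECK_DEBUG(...) from assert.h
//   }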
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <string>
namespace tensorrt_llm::common
{
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
}
} // namespace tensorrt_llm::common
class DebugConfig
{
public:
static bool isCheckDebugEnabled();
};
#if defined(_WIN32)
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
#else
#define TLLM_LIKELY(x) __builtin_expect((x), 1)
#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
#endif
#define TLLM_CHECK(val) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} while (0)
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} while (0)
#define TLLM_CHECK_DEBUG(val) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} \
} while (0)
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} \
} while (0)
#define TLLM_THROW(...) \
do \
{ \
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
} while (0)
#define TLLM_WRAP(ex) \
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
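// Usage sketch (illustrative only; the surrounding function is hypothetical). The
// TLLM_CHECK* macros throw tensorrt_llm::common::TllmException on failure, and the
// *_WITH_INFO variants take a printf-style format string:
//
//   void setBeamWidth(int beamWidth)
//   {
//       TLLM_CHECK(beamWidth != 0);
//       TLLM_CHECK_WITH_INFO(beamWidth > 0, "beamWidth must be positive, got %d", beamWidth);
//       TLLM_CHECK_DEBUG(beamWidth <= 1024); // only evaluated when TLLM_DEBUG_MODE=1
//   }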
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace tensorrt_llm
{
namespace common
{
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
: mCublasHandle(cublasHandle)
, mCublasLtHandle(cublasltHandle)
, mStream(stream)
, mCublasWorkspace(workspace)
{
}
CublasMMWrapper::~CublasMMWrapper() {}
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
: mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream)
{
}
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
{
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
}
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
{
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
}
void CublasMMWrapper::destroyDescriptors()
{
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
mOperationDesc = NULL;
mADesc = NULL;
mBDesc = NULL;
mCDesc = NULL;
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo,
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
{
bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ usingCublasLt);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
{
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
// TODO: default cublas libs
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
int batch_count = 1;
// FP32 uses cuBLAS by default; FP16 uses cuBLASLt by default.
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (usingCublasLt)
{
if (hasAlgo)
{
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
}
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
sync_check_cuda_error();
}
else
{
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
// Use the default heuristic to choose a tactic, since cuBLAS does not allow selecting tactics on Ampere and newer
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error();
}
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
float const f_beta)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::setWorkspace(void* workspace)
{
mCublasWorkspace = workspace;
}
void CublasMMWrapper::setFP32GemmConfig()
{
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
}
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
}
#ifdef ENABLE_BF16
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
}
#endif
#ifdef ENABLE_FP8
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
}
#endif
void CublasMMWrapper::setGemmConfig(
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
mAType = aType;
mBType = bType;
mCType = cType;
bool isFp16ComputeType = computeType == CUDA_R_16F;
if (isFp16ComputeType)
{
mComputeType = CUBLAS_COMPUTE_16F;
mScaleType = CUDA_R_16F;
}
else
{
mComputeType = CUBLAS_COMPUTE_32F;
mScaleType = CUDA_R_32F;
}
}
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
if (data_type == CUDA_R_16F)
{
return HALF_DATATYPE;
}
else if (data_type == CUDA_R_32F)
{
return FLOAT_DATATYPE;
}
else if (data_type == CUDA_R_8I)
{
return INT8_DATATYPE;
}
#ifdef ENABLE_BF16
else if (data_type == CUDA_R_16BF)
{
return BFLOAT16_DATATYPE;
}
#endif
return FLOAT_DATATYPE;
}
void CublasMMWrapper::setStream(cudaStream_t stream)
{
mStream = stream;
}
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
{
return false;
}
sync_check_cuda_error();
return true;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
sync_check_cuda_error();
return heuristics;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
return {};
#else
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
// Restrict reduction algorithms for numerical stability and better determinism
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
uint32_t pointer_mode_mask = 0;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
int return_count = 0;
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
heuristics.size(), heuristics.data(), &return_count));
heuristics.resize(return_count);
return heuristics;
#endif
}
} // namespace common
} // namespace tensorrt_llm
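// Usage sketch (illustrative only; cuBLAS/cuBLASLt handle, stream and workspace
// creation are assumed to be done by the caller). A typical FP16 GEMM with tactic
// selection through the heuristic overload of Gemm():
//
//   CublasMMWrapper wrapper(cublasHandle, cublasLtHandle, stream, workspace);
//   wrapper.setFP16GemmConfig();
//   wrapper.createDescriptors(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);
//   auto const tactics = wrapper.getTactics(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);
//   std::optional<cublasLtMatmulHeuristicResult_t> best;
//   if (!tactics.empty())
//   {
//       best = tactics.front(); // simple default: take the first heuristic result
//   }
//   wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, A, lda, B, ldb, C, ldc, best);
//   wrapper.destroyDescriptors();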
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <array>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>
namespace tensorrt_llm
{
namespace common
{
class CublasMMWrapper
{
protected:
std::shared_ptr<cublasHandle_t> mCublasHandle;
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
cudaDataType_t mAType{};
cudaDataType_t mBType{};
cudaDataType_t mCType{};
cublasComputeType_t mComputeType{};
cudaDataType_t mScaleType{};
cublasLtMatmulDesc_t mOperationDesc{NULL};
cublasLtMatrixLayout_t mADesc{NULL};
cublasLtMatrixLayout_t mBDesc{NULL};
cublasLtMatrixLayout_t mCDesc{NULL};
cudaStream_t mStream;
void* mCublasWorkspace = nullptr;
private:
bool descriptorsCreated() const
{
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
}
public:
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
cudaStream_t stream, void* workspace);
~CublasMMWrapper();
CublasMMWrapper(CublasMMWrapper const& wrapper);
/********************** GEMMs **********************/
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
float const f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
/********************** Tactic selection helpers **********************/
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
/********************** Utils **********************/
void setWorkspace(void* workspace);
void setFP32GemmConfig();
void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#ifdef ENABLE_BF16
void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
#endif
#ifdef ENABLE_FP8
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#endif
void setStream(cudaStream_t stream);
void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);
CublasDataType getCublasDataType(cudaDataType_t data_type);
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
void setScaleDescriptors(void* scale_a, void* scale_b);
void destroyDescriptors();
cublasHandle_t getCublasHandle()
{
return *(this->mCublasHandle);
}
cublasLtHandle_t getCublasLtHandle() const
{
return *(this->mCublasLtHandle);
}
};
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// We intentionally do not include cublas_api.h here. It defines the CUBLAS_VER_*
// macros, but that alone is not sufficient to determine whether cublas.h,
// cublas_v2.h or cublasLt.h should be included.
#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH)
#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
<= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
< TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
>= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
> TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
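// Usage sketch (illustrative only): these macros compare the CUBLAS_VER_* values
// exposed by whichever cuBLAS header the including file has already pulled in, e.g.
// to guard a code path that requires cuBLAS 12.0 or newer:
//
//   #include <cublasLt.h>
//   #include "tensorrt_llm/common/cublasVersionCheck.h"
//
//   #if TLLM_CUBLAS_VER_GE(12, 0, 0)
//   // code that relies on cuBLAS >= 12.0 APIs
//   #endif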
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
#endif
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace common
} // namespace tensorrt_llm
// Operator definitions intentionally in global namespace
namespace
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hadd2(x, y);
};
#endif
#endif
} // namespace
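// Usage sketch (illustrative only; the kernel below is hypothetical). The bf16h*
// helpers give a single code path that lowers to native __nv_bfloat162 intrinsics on
// SM80+ and to the float fallbacks above on older architectures:
//
//   __global__ void axpyBf16x2(__nv_bfloat162 const* x, __nv_bfloat162* y, __nv_bfloat16 a, int n)
//   {
//       int const i = blockIdx.x * blockDim.x + threadIdx.x;
//       if (i < n)
//       {
//           auto const a2 = tensorrt_llm::common::bf162bf162(a);    // broadcast scalar to both lanes
//           y[i] = tensorrt_llm::common::bf16hfma2(a2, x[i], y[i]); // y = a * x + y
//       }
//   }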
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "Failed to open the CUDA driver library.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
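// Usage sketch (illustrative only; cubinImage is a caller-provided module blob and
// "myKernel" a hypothetical symbol). Driver API entry points are resolved lazily
// through the shared wrapper instance, and TLLM_CU_CHECK turns a failing CUresult
// into a TllmException:
//
//   auto const driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
//   CUmodule module;
//   CUfunction kernel;
//   TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubinImage));
//   TLLM_CU_CHECK(driver->cuModuleGetFunction(&kernel, module, "myKernel"));
//   TLLM_CU_CHECK(driver->cuModuleUnload(module));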
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <algorithm>
#include <cstdio>
#include <cuda_fp16.h>
#include <limits>
#include <type_traits>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_FP8
constexpr int CTA_SIZE = 256;
template <bool QUANTIZE>
__inline__ __device__ float scale(float a, float b)
{
return QUANTIZE ? a / b : a * b;
}
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
{
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
{
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
{
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
dst[tid] = (T_OUT) (static_cast<float>(tmp));
}
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
{
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
sync_check_cuda_error();
}
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, half, float>(
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
__device__ float atomicMaxExtd(float* address, float val)
{
assert(val >= 0);
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
return __uint_as_float(old);
}
template <typename T>
inline __device__ T atomicMaxExtdV2(T* address, T val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
// The address in 64 bits.
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
// Pack the input value into 32 bits.
union
{
T v[2];
uint16_t u[2];
} old, tmp = {};
int const loc = (address_u64 & 0x2) >> 1;
tmp.v[loc] = val;
// 4B aligned pointer.
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
if constexpr (std::is_same_v<T, half>)
{
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
// Return the correct half.
return old.v[loc];
#endif
}
__device__ half atomicMaxExtd(half* address, half val)
{
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_half(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
}
return __ushort_as_half(old);
}
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_bfloat16(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
}
return __ushort_as_bfloat16(old);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
{
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
{
float max = 0.f;
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (std::is_same_v<T_S, float>)
{
atomicMaxExtd(quant_ptr + col, scale);
}
else
{
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
atomicMaxExtd(quant_ptr + col, scale);
else
atomicMaxExtdV2(quant_ptr + col, scale);
}
#else // Vector atomics require __CUDA_ARCH__ >= 900
atomicMaxExtd(quant_ptr + col, scale);
#endif
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
auto const nrows = size / n;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
{
auto val = fabs(static_cast<float>(weights[row * n + i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
quant_ptr[row] = scale;
}
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
float max = 0.f;
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
atomicMaxExtd(quant_ptr, scale);
}
}
}
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 block(CTA_SIZE);
dim3 grid(numel / lda);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
#ifdef ENABLE_BF16
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
#endif
template <typename T_OUT, typename T_S, typename T_IN>
__global__ void dynamicQuantizeMatrixPerToken(
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
{
extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
auto const nrows = numel / lda;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
auto const in = input[row * lda + i];
shmem[i] = in;
auto val = fabs(static_cast<float>(in));
max = max > val ? max : val;
}
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
// true means we are quantizing
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
}
if (threadIdx.x == 0)
{
quant_ptr[row] = s;
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 grid(numel / lda);
bool use_shmem = true;
auto const shmem_size = lda * sizeof(T_IN);
if (shmem_size >= (48 << 10))
{
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
use_shmem = ret == cudaSuccess;
}
if (use_shmem)
{
// ensure the threadblock is as large as possible to increase occupancy
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
}
else
{
dim3 block(CTA_SIZE);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream);
#ifdef ENABLE_FP8
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm
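A minimal usage sketch (not part of the commit): driving the launcher above to compute per-token scales for a [num_tokens, hidden_size] half-precision buffer. The function and buffer names are assumptions, and the build must define ENABLE_FP8.
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/common/cudaFp8Utils.h"

// d_activations and d_scales are caller-owned device pointers; the PER_TOKEN
// kernel above writes one scale per row.
void computePerTokenScalesExample(
    half const* d_activations, float* d_scales, int64_t num_tokens, int64_t hidden_size, cudaStream_t stream)
{
    using namespace tensorrt_llm::common;
    invokeComputeFP8QuantizeScale<float, half>(
        d_scales, d_activations, num_tokens * hidden_size, hidden_size, QuantizeMode::PER_TOKEN, stream);
}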
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <stdint.h>
#define FP8_MHA
#define FUSE_GEMM_ACT
#define FP8_GEMM_OUTPUT_QUANT_DISABLE
#ifdef FUSE_GEMM_ACT
#define USE_QGMMA
#endif
namespace tensorrt_llm
{
namespace common
{
constexpr float FP8_E4M3_MAX = 448.0f;
enum QuantizeMode
{
PER_CHANNEL,
PER_TENSOR,
PER_CHANNEL_WEIGHT_PER_TENSOR_ACT,
PER_TOKEN,
};
// Packed Data Type
typedef struct __CUDA_ALIGN__(32)
{
float array[8];
} float8;
typedef struct __CUDA_ALIGN__(16)
{
half array[8];
} half8;
typedef struct __CUDA_ALIGN__(8)
{
half2 array[2];
} half2_2;
typedef struct __CUDA_ALIGN__(8)
{
half array[4];
} half_4;
#ifdef ENABLE_BF16
typedef struct __CUDA_ALIGN__(4)
{
__nv_bfloat16 array[2];
} __nv_bfloat16_2;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat162 x, y;
} __nv_bfloat162_2_xy;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat16 array[4];
} __nv_bfloat164;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat162 array[2];
} __nv_bfloat162_2;
typedef struct __CUDA_ALIGN__(16)
{
__nv_bfloat16 array[8];
} __nv_bfloat168;
typedef struct __CUDA_ALIGN__(16)
{
__nv_bfloat162 array[4];
} __nv_bfloat162_4;
typedef struct __CUDA_ALIGN__(32)
{
__nv_bfloat16 array[16];
} __nv_bfloat1616;
#endif
#ifdef ENABLE_FP8
typedef struct __CUDA_ALIGN__(2)
{
__nv_fp8_e4m3 array[2];
} __nv_fp8_2_e4m3;
typedef struct __CUDA_ALIGN__(4)
{
__nv_fp8_e4m3 array[4];
} __nv_fp8_4_e4m3;
typedef struct __CUDA_ALIGN__(4)
{
__nv_fp8x2_e4m3 array[2];
} __nv_fp8x2_x2_e4m3;
typedef struct __CUDA_ALIGN__(8)
{
__nv_fp8_e4m3 array[8];
} __nv_fp8_8_e4m3;
typedef struct __CUDA_ALIGN__(8)
{
__nv_fp8x2_e4m3 array[4];
} __nv_fp8x2_x4_e4m3;
typedef struct __CUDA_ALIGN__(16)
{
__nv_fp8_e4m3 array[16];
} __nv_fp8x16_e4m3;
#endif
// PackType is specialized only for BF16 and FP8; other types fall back to float
template <typename T, int PACK_SIZE>
struct PackType
{
using type = float;
};
#ifdef ENABLE_BF16
template <>
struct PackType<__nv_bfloat16, 2>
{
using type = __nv_bfloat16_2;
};
template <>
struct PackType<__nv_bfloat16, 4>
{
using type = __nv_bfloat164;
};
template <>
struct PackType<__nv_bfloat16, 8>
{
using type = __nv_bfloat168;
};
#endif
#ifdef ENABLE_FP8
template <>
struct PackType<__nv_fp8_e4m3, 2>
{
using type = __nv_fp8_2_e4m3;
};
template <>
struct PackType<__nv_fp8_e4m3, 4>
{
using type = __nv_fp8_4_e4m3;
};
template <>
struct PackType<__nv_fp8_e4m3, 8>
{
using type = __nv_fp8_8_e4m3;
};
#endif
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
{
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
{
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out;
}
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
{
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}
__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
{
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out;
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream);
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
} // namespace common
} // namespace tensorrt_llm
#endif // ENABLE_FP8
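A small host-only sketch of the scaling rule these kernels implement: scale = max(amax / FP8_E4M3_MAX, min_scaling_factor); quantization divides by the scale and dequantization multiplies by it. The numbers below are made up for illustration.
#include <algorithm>
#include <cstdio>

int main()
{
    constexpr float kFp8E4m3Max = 448.0f;                 // FP8_E4M3_MAX above
    constexpr float kMinScalingFactor = 1.0f / (448.0f * 512.0f);
    float const amax = 12.5f;                             // example per-token absolute maximum
    float const scale = std::max(amax / kFp8E4m3Max, kMinScalingFactor);
    float const x = 3.75f;                                // example activation
    float const quantized = x / scale;                    // value that gets rounded into FP8
    float const dequantized = quantized * scale;          // recovered value
    std::printf("scale=%g quantized=%g dequantized=%g\n", scale, quantized, dequantized);
    return 0;
}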
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <assert.h>
#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>
#endif
namespace tensorrt_llm
{
namespace common
{
template <typename T>
inline __device__ T ldg(T const* val)
{
return __ldg(val);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter
{
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2>
{
using Type = half;
};
template <>
struct TypeConverter<half>
{
using Type = half2;
};
#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162>
{
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16>
{
using Type = __nv_bfloat162;
};
#endif // ENABLE_BF16
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
template <typename T>
inline __device__ T hadd2(T a, T b)
{
return __hadd2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T add(T a, T b)
{
return a + b;
}
template <>
inline __device__ half2 add(half2 a, half2 b)
{
return __hadd2(a, b);
}
template <>
inline __device__ half add(half a, half b)
{
return __hadd(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return bf16hadd(a, b);
}
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
return bf16hadd(a, __float2bfloat16(b));
}
#endif // ENABLE_BF16
// addition of three values
template <typename T>
inline __device__ T add(T a, T b, T c)
{
return a + b + c;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hadd(a, b, c);
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hadd2(a, b, c);
}
#endif // ENABLE_BF16
// addition of four values
template <typename T>
inline __device__ T add(T a, T b, T c, T d)
{
return (T) ((float) a + (float) b + (float) c + (float) d);
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
return bf16hadd(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hsub2(T a, T b)
{
return __hsub2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hsub2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b)
{
return __hmul2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T mul(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hmul(a, b, c);
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c, T d)
{
return a * b * c + d;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
return bf16hfma2(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c)
{
return a * b + c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
template <>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hfma(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hexp2(T a)
{
return h2exp(a);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
return bf16exp2(a);
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return bf162bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
#endif
// Binary maximum: compute the max of two values.
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_max(float2 val1, float2 val2)
{
float2 out;
out.x = fmaxf(val1.x, val2.x);
out.y = fmaxf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_max(half2 val1, half2 val2)
{
return __hmax2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmax2(val1, val2);
}
#endif // ENABLE_BF16
// Binary minimum: compute the min of two values.
template <typename T>
__device__ inline T cuda_min(T val1, T val2)
{
return (val1 < val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_min(float2 val1, float2 val2)
{
float2 out;
out.x = fminf(val1.x, val2.x);
out.y = fminf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_min(half2 val1, half2 val2)
{
return __hmin2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmin2(val1, val2);
}
#endif // ENABLE_BF16
// Helper function of clamping the val into the given range.
template <typename T>
inline __device__ T cuda_clamp(T val, T minVal, T maxVal)
{
return cuda_min(cuda_max(val, minVal), maxVal);
}
#ifdef ENABLE_FP8
template <>
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
template <>
__device__ inline half2 cuda_cast<half2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_half2(&val);
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
return (float) val;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_bfloat2(&val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
// no impl
return 0;
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
}
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm
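An illustrative kernel (name and launch shape assumed) showing the helpers above on packed half2 data: cuda_abs for an element-wise absolute value and cuda_clamp to restrict the result to a range.
#include <cuda_fp16.h>
#include "tensorrt_llm/common/cudaTypeUtils.cuh"

__global__ void clampAbsExample(half2* data, int n, half2 lo, half2 hi)
{
    using namespace tensorrt_llm::common;
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
    {
        half2 const v = cuda_abs(data[idx]); // element-wise |x| on both halves
        data[idx] = cuda_clamp(v, lo, hi);   // element-wise clamp into [lo, hi]
    }
}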
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <cuda_runtime.h>
namespace tensorrt_llm::common
{
Logger::Logger()
{
char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY");
bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON");
auto const* levelName = std::getenv("TLLM_LOG_LEVEL");
if (levelName != nullptr)
{
auto level = [levelName = std::string(levelName)]()
{
if (levelName == "TRACE")
return TRACE;
if (levelName == "DEBUG")
return DEBUG;
if (levelName == "INFO")
return INFO;
if (levelName == "WARNING")
return WARNING;
if (levelName == "ERROR")
return ERROR;
TLLM_THROW("Invalid log level: %s", levelName.c_str());
}();
// If TLLM_LOG_FIRST_RANK_ONLY=ON, set the log level of all other devices to ERROR
if (isFirstRankOnly)
{
auto const deviceId = getDevice();
if (deviceId != 1)
{
level = ERROR;
}
}
setLevel(level);
}
}
void Logger::log(std::exception const& ex, Logger::Level level)
{
log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what());
}
Logger* Logger::getLogger()
{
thread_local Logger instance;
return &instance;
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stringUtils.h"
namespace tensorrt_llm::common
{
class Logger
{
// On Windows, the file wingdi.h is included which has
// #define ERROR 0
// This breaks everywhere ERROR is used in the Level enum
#ifdef _WIN32
#undef ERROR
#endif // _WIN32
public:
enum Level
{
TRACE = 0,
DEBUG = 10,
INFO = 20,
WARNING = 30,
ERROR = 40
};
static Logger* getLogger();
Logger(Logger const&) = delete;
void operator=(Logger const&) = delete;
#if defined(_MSC_VER)
template <typename... Args>
void log(Level level, char const* format, Args const&... args);
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args);
#else
template <typename... Args>
void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
#endif
template <typename... Args>
void log(Level level, std::string const& format, Args const&... args)
{
return log(level, format.c_str(), args...);
}
template <typename... Args>
void log(Level const level, int const rank, std::string const& format, Args const&... args)
{
return log(level, rank, format.c_str(), args...);
}
void log(std::exception const& ex, Level level = Level::ERROR);
Level getLevel() const
{
return level_;
}
void setLevel(Level const level)
{
level_ = level;
log(INFO, "Set logger level to %s", getLevelName(level));
}
bool isEnabled(Level const level) const
{
return level_ <= level;
}
private:
static auto constexpr kPREFIX = "[TensorRT-LLM]";
#ifndef NDEBUG
Level const DEFAULT_LOG_LEVEL = DEBUG;
#else
Level const DEFAULT_LOG_LEVEL = INFO;
#endif
Level level_ = DEFAULT_LOG_LEVEL;
Logger(); // NOLINT(modernize-use-equals-delete)
static inline char const* getLevelName(Level const level)
{
switch (level)
{
case TRACE: return "TRACE";
case DEBUG: return "DEBUG";
case INFO: return "INFO";
case WARNING: return "WARNING";
case ERROR: return "ERROR";
}
TLLM_THROW("Unknown log level: %d", level);
}
static inline std::string getPrefix(Level const level)
{
return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
}
static inline std::string getPrefix(Level const level, int const rank)
{
return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
}
};
template <typename... Args>
void Logger::log(Logger::Level level, char const* format, Args const&... args)
{
if (isEnabled(level))
{
auto const fmt = getPrefix(level) + format;
auto& out = level_ < WARNING ? std::cout : std::cerr;
if constexpr (sizeof...(args) > 0)
{
out << fmtstr(fmt.c_str(), args...);
}
else
{
out << fmt;
}
out << std::endl;
}
}
template <typename... Args>
void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args)
{
if (isEnabled(level))
{
auto const fmt = getPrefix(level, rank) + format;
auto& out = level_ < WARNING ? std::cout : std::cerr;
if constexpr (sizeof...(args) > 0)
{
out << fmtstr(fmt.c_str(), args...);
}
else
{
out << fmt;
}
out << std::endl;
}
}
#define TLLM_LOG(level, ...) \
do \
{ \
auto* const logger = tensorrt_llm::common::Logger::getLogger(); \
if (logger->isEnabled(level)) \
{ \
logger->log(level, __VA_ARGS__); \
} \
} while (0)
#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__)
#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__)
#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__)
#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__)
#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__)
#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__)
} // namespace tensorrt_llm::common
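A short usage sketch for the logging macros above (function name and messages are assumptions); the effective level can be overridden at runtime through the TLLM_LOG_LEVEL environment variable read in logger.cpp.
#include "tensorrt_llm/common/logger.h"

void loggingExample(int numTokens, int rank)
{
    // printf-style formatting; the "[TensorRT-LLM][LEVEL]" prefix is added automatically.
    TLLM_LOG_INFO("processing %d tokens", numTokens);
    TLLM_LOG_DEBUG("only emitted when the level is DEBUG or TRACE");
    if (numTokens == 0)
    {
        TLLM_LOG_WARNING("received an empty batch on rank %d", rank);
    }
}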
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <float.h>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
struct QuantTypeStaticVals;
template <>
struct QuantTypeStaticVals<int8_t>
{
static constexpr float MAX_VAL = 127.f;
static constexpr float MIN_SCALING_FACTOR = 0.f;
static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX;
};
#ifdef ENABLE_FP8
template <>
struct QuantTypeStaticVals<__nv_fp8_e4m3>
{
static constexpr float MAX_VAL = 448.f;
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720
static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f);
static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f);
};
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm
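A host-side sketch (function name and include path are assumptions) of how QuantTypeStaticVals lets one scale computation cover both int8 and FP8 targets.
#include <algorithm>
#include <cstdint>
#include "tensorrt_llm/common/quantTypeUtils.cuh" // assumed include path for the traits above

template <typename QuantT>
float computeScaleFromAmax(float amax)
{
    using Vals = tensorrt_llm::common::QuantTypeStaticVals<QuantT>;
    // Same rule as the FP8 kernels: clamp the scale away from zero.
    return std::max(amax / Vals::MAX_VAL, Vals::MIN_SCALING_FACTOR);
}

// computeScaleFromAmax<int8_t>(3.2f) uses a dynamic range of 127, while
// computeScaleFromAmax<__nv_fp8_e4m3>(3.2f) uses 448 (when ENABLE_FP8 is defined).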
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm
{
namespace common
{
class QuantMode
{
// [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py
public:
using BaseType = std::uint32_t;
explicit constexpr QuantMode(BaseType value) noexcept
: mValue{value}
{
}
QuantMode() noexcept = default;
constexpr QuantMode(QuantMode const&) noexcept = default;
constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;
static constexpr QuantMode none() noexcept
{
return QuantMode(BaseType(0));
}
static constexpr QuantMode int4Weights() noexcept
{
return QuantMode(BaseType(1u) << 0);
}
static constexpr QuantMode int8Weights() noexcept
{
return QuantMode(BaseType(1u) << 1);
}
static constexpr QuantMode activations() noexcept
{
return QuantMode(BaseType(1u) << 2);
}
static constexpr QuantMode perChannelScaling() noexcept
{
return QuantMode(BaseType(1u) << 3);
}
static constexpr QuantMode perTokenScaling() noexcept
{
return QuantMode(BaseType(1u) << 4);
}
static constexpr QuantMode perGroupScaling() noexcept
{
return QuantMode(BaseType(1u) << 5);
}
static constexpr QuantMode int8KvCache() noexcept
{
return QuantMode(BaseType(1u) << 6);
}
static constexpr QuantMode fp8KvCache() noexcept
{
return QuantMode(BaseType(1u) << 7);
}
static constexpr QuantMode fp8Qdq() noexcept
{
return QuantMode(BaseType(1u) << 8);
}
static constexpr QuantMode fp8RowWise() noexcept
{
return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
}
static constexpr QuantMode w4a8QServe() noexcept
{
return QuantMode(BaseType(1u) << 10);
}
constexpr BaseType value() const noexcept
{
return mValue;
}
constexpr bool isSet(QuantMode const& mode) const noexcept
{
return (mValue & mode.value()) == mode.value();
}
constexpr bool hasInt4Weights() const noexcept
{
return isSet(int4Weights());
}
constexpr bool hasInt8Weights() const noexcept
{
return isSet(int8Weights());
}
constexpr bool hasActivations() const noexcept
{
return isSet(activations());
}
constexpr bool hasPerChannelScaling() const noexcept
{
return isSet(perChannelScaling());
}
constexpr bool hasPerTokenScaling() const noexcept
{
return isSet(perTokenScaling());
}
constexpr bool hasPerGroupScaling() const noexcept
{
return isSet(perGroupScaling());
}
constexpr bool hasStaticActivationScaling() const noexcept
{
return !hasPerTokenScaling();
}
constexpr bool hasInt8KvCache() const noexcept
{
return isSet(int8KvCache());
}
constexpr bool hasFp8KvCache() const noexcept
{
return isSet(fp8KvCache());
}
constexpr bool hasFp8Qdq() const noexcept
{
return isSet(fp8Qdq());
}
constexpr bool hasFp8RowWise() const noexcept
{
return isSet(fp8RowWise());
}
constexpr bool hasKvCacheQuant() const noexcept
{
return hasInt8KvCache() || hasFp8KvCache();
}
static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
bool useW4a8QServe = false)
{
QuantMode quantMode{};
if (quantizeWeights)
{
if (useInt4Weights)
quantMode += int4Weights();
else
quantMode += int8Weights();
}
if (quantizeActivations)
{
quantMode += activations();
}
if (perChannel)
{
quantMode += QuantMode::perChannelScaling();
}
if (perToken)
{
quantMode += QuantMode::perTokenScaling();
}
if (perGroup)
{
quantMode += QuantMode::perGroupScaling();
}
if (useInt8KvCache)
{
quantMode += int8KvCache();
}
if (useFp8KvCache)
{
quantMode += fp8KvCache();
}
if (useFp8Qdq)
{
quantMode += fp8Qdq();
}
if (useFp8RowWise)
{
quantMode += fp8RowWise();
}
if (useW4a8QServe)
{
quantMode += w4a8QServe();
}
return quantMode;
}
static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
{
return fromDescription(true, true, perToken, perChannel);
}
static constexpr QuantMode useQServe(bool perGroup)
{
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
}
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
{
return fromDescription(true, false, false, false, perGroup, useInt4Weights);
}
static QuantMode const fromQuantAlgo(
std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
{
QuantMode quantMode{};
if (quantAlgo == "W8A16")
{
quantMode = useWeightOnly(false, false);
}
else if (quantAlgo == "W4A16")
{
quantMode = useWeightOnly(true, false);
}
else if (quantAlgo == "W4A16_AWQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W4A8_AWQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
{
quantMode = useQServe(false);
}
else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
{
quantMode = useQServe(true);
}
else if (quantAlgo == "W4A16_GPTQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
{
quantMode = useSmoothQuant(false, true);
}
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
{
quantMode = useSmoothQuant(false, false);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
{
quantMode = useSmoothQuant(true, true);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
{
quantMode = useSmoothQuant(false, true);
}
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
{
quantMode = useSmoothQuant(true, false);
}
else if (quantAlgo == "FP8")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
}
else if (quantAlgo == "FP8_ROWWISE")
{
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
}
if (kvCacheQuantAlgo == "INT8")
{
quantMode += int8KvCache();
}
else if (kvCacheQuantAlgo == "FP8")
{
quantMode += fp8KvCache();
}
return quantMode;
}
constexpr QuantMode operator+(QuantMode const& other) const noexcept
{
return QuantMode(mValue | other.mValue);
}
constexpr QuantMode& operator+=(QuantMode const& other) noexcept
{
return *this = *this + other;
}
constexpr QuantMode operator-(QuantMode const& other) const noexcept
{
return QuantMode(mValue & ~other.mValue);
}
constexpr QuantMode& operator-=(QuantMode const& other) noexcept
{
return *this = *this - other;
}
constexpr bool operator==(QuantMode const& other) const noexcept
{
return mValue == other.mValue;
}
constexpr bool operator!=(QuantMode const& other) const noexcept
{
return !(*this == other);
}
private:
BaseType mValue{0};
};
} // namespace common
} // namespace tensorrt_llm
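An illustrative sketch (include path assumed) of how QuantMode values compose: modes combine with operator+ and are queried with the has*() helpers, and fromQuantAlgo rebuilds the same bitmask from the string names used by checkpoints.
#include <cassert>
#include "tensorrt_llm/common/quantization.h" // assumed include path for QuantMode

int main()
{
    using tensorrt_llm::common::QuantMode;
    // W8A8 SmoothQuant with per-token and per-channel scaling, plus an FP8 KV cache.
    QuantMode mode = QuantMode::useSmoothQuant(/*perToken=*/true, /*perChannel=*/true) + QuantMode::fp8KvCache();
    assert(mode.hasActivations() && mode.hasInt8Weights());
    assert(mode.hasPerTokenScaling() && mode.hasPerChannelScaling());
    assert(mode.hasKvCacheQuant() && !mode.hasInt8KvCache());
    // The equivalent mode reconstructed from quantization algorithm names.
    QuantMode const fromAlgo = QuantMode::fromQuantAlgo("W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN", "FP8");
    assert(fromAlgo == mode);
    return 0;
}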
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <float.h>
#include <type_traits>
namespace cg = cooperative_groups;
namespace tensorrt_llm
{
namespace common
{
template <int VPT>
struct BytesToType;
template <>
struct BytesToType<1>
{
using type = uint8_t;
};
template <>
struct BytesToType<2>
{
using type = uint16_t;
};
template <>
struct BytesToType<4>
{
using type = uint32_t;
};
template <>
struct BytesToType<8>
{
using type = uint64_t;
};
template <>
struct BytesToType<16>
{
using type = float4;
};
template <int Bytes>
__device__ inline void copy(void const* local, void* data)
{
using T = typename BytesToType<Bytes>::type;
T const* in = static_cast<T const*>(local);
T* out = static_cast<T*>(data);
*out = *in;
}
static float constexpr HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); // __shfl_sync on bf16 returns float when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use float division (blockDim.x / 32.f) rather than an integer shift so the
// partial last warp is still counted when blockDim.x is not a multiple of 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <typename T>
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // warp-level max
if (lane == 0) // lane 0 records its warp's max
shared[wid] = val;
__syncthreads();
// Use float division (blockDim.x / 32.f) rather than an integer shift so the
// partial last warp is still counted when blockDim.x is not a multiple of 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockAllReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // warp-level max
if (lane == 0) // lane 0 records its warp's max
shared[wid] = val;
__syncthreads();
// Use float division (blockDim.x / 32.f) rather than an integer shift so the
// partial last warp is still counted when blockDim.x is not a multiple of 32
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T) (0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val)
{
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T) 0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T) (0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val)
{
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // warp-level max for each of the NUM values
if (lane == 0) // lane 0 records its warp's maxima
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[wid][i] = val[i];
}
}
__syncthreads();
// Use float division (blockDim.x / 32.f) rather than an integer shift so the
// partial last warp is still counted when blockDim.x is not a multiple of 32
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[lane][i] : (T) -1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T) 0.0f;
}
template <int NUM>
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
{
cg::thread_block cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
int const tid = cta.thread_rank();
int const blockz = blockDim.x;
for (int i = 0; i < NUM; i++)
{
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
#else
// TODO Add implementation here
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
assert(false);
}
#endif
}
cg::sync(cta);
if (tid == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
float beta = 0.0f;
for (int j = 0; j < blockz; j += 32)
{
beta += cgBlockReduceSumElements_shm[i * blockz + j];
}
element_list[i] = beta;
}
}
}
template <typename T, int MAX_K>
struct TopK
{
int p[MAX_K]; // indices; trailing entries are -1 while the array is not full
T u[MAX_K]; // values in descending order; invalid entries hold -MAX_T_VAL
__device__ __forceinline__ void insert(T const elem, int const elem_id)
{
if (elem_id < 0)
{
return;
}
// Condition of updating the array
// 1. array is not full
// 2. elem is greater than the smallest (last) element in the array
// 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller
bool const need_update
= (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]);
if (!need_update)
{
return;
}
// Find suitable index for the new element
int i;
for (i = MAX_K - 2; i >= 0; --i)
{
bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]);
if (!need_decrease)
break;
}
// Move elements to correct positions
for (int k = MAX_K - 2; k >= i; --k)
{
p[k + 1] = p[k];
u[k + 1] = u[k];
}
p[i] = elem_id;
u[i] = elem;
}
__device__ __forceinline__ void init()
{
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
for (int i = 0; i < MAX_K; i++)
{
p[i] = -1;
u[i] = -MAX_T_VAL;
}
}
};
template <typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
{
TopK<T, MAX_K> res = a;
for (int i = 0; i < MAX_K; ++i)
res.insert(b.u[i], b.p[i]);
return res;
}
template <typename T>
struct TopK_2
{
int p = -1;
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
__device__ __forceinline__ void insert(T elem, int elem_id)
{
if (elem > u)
{
u = elem;
p = elem_id;
}
}
__device__ __forceinline__ void init()
{
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
p = -1;
}
};
template <typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
{
return a.u > b.u ? a : b;
}
template <typename T>
__device__ __forceinline__ T clamp_inf_for_half(float const input)
{
return input;
}
template <>
__device__ __forceinline__ half clamp_inf_for_half(float const input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
}
} // namespace common
} // namespace tensorrt_llm
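An illustrative kernel (name and include path assumed) showing the block-reduction pattern the quantization kernels rely on: one block per row, each thread strides over the columns, and blockReduceMax yields the row's absolute maximum.
#include <cuda_runtime.h>
#include "tensorrt_llm/common/reduceKernelUtils.cuh" // assumed include path

__global__ void rowAbsMaxExample(float const* input, float* row_amax, int cols)
{
    using namespace tensorrt_llm::common;
    float local_max = 0.f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x)
    {
        local_max = fmaxf(local_max, fabsf(input[blockIdx.x * cols + i]));
    }
    float const block_max = blockReduceMax<float>(local_max); // thread 0 ends up with the block-wide maximum
    if (threadIdx.x == 0)
    {
        row_amax[blockIdx.x] = block_max;
    }
}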
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/assert.h"
#include <cerrno>
#include <cstdarg>
#include <cstring>
#include <iostream>
#include <string>
namespace tensorrt_llm::common
{
namespace
{
std::string vformat(char const* fmt, va_list args)
{
va_list args0;
va_copy(args0, args);
auto const size = vsnprintf(nullptr, 0, fmt, args0);
if (size <= 0)
return "";
std::string stringBuf(size, char{});
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
return stringBuf;
}
} // namespace
std::string fmtstr(char const* format, ...)
{
va_list args;
va_start(args, format);
std::string result = vformat(format, args);
va_end(args);
return result;
};
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
{
std::unordered_set<std::string> values;
if (!input.empty())
{
std::stringstream valStream(input);
std::string val;
while (std::getline(valStream, val, delimiter))
{
if (!val.empty())
{
values.insert(val);
}
}
}
return values;
};
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#if ENABLE_BF16
#include <cuda_bf16.h>
#endif // ENABLE_BF16
#include <cuda_fp16.h>
#include <memory> // std::make_unique
#include <sstream> // std::stringstream
#include <string>
#include <unordered_set>
#include <vector>
namespace tensorrt_llm::common
{
#if ENABLE_BF16
static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __nv_bfloat16 const& val)
{
stream << __bfloat162float(val);
return stream;
}
#endif // ENABLE_BF16
static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __half const& val)
{
stream << __half2float(val);
return stream;
}
inline std::string fmtstr(std::string const& s)
{
return s;
}
inline std::string fmtstr(std::string&& s)
{
return s;
}
#if defined(_MSC_VER)
std::string fmtstr(char const* format, ...);
#else
std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2)));
#endif
// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows
// The alternative is __FUNCSIG__, which is similar but not identical
#if defined(_WIN32)
#define __PRETTY_FUNCTION__ __FUNCSIG__
#endif
auto constexpr kDefaultDelimiter = ", ";
template <typename U, typename TStream, typename T>
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
out << "(";
if (size > 0)
{
for (size_t i = 0; i < size - 1; ++i)
{
out << static_cast<U>(arr[i]) << delim;
}
out << static_cast<U>(arr[size - 1]);
}
out << ")";
return out;
}
template <typename TStream, typename T>
inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
return arr2outCasted<T>(out, arr, size, delim);
}
template <typename T>
inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
std::stringstream ss;
return arr2out(ss, arr, size, delim).str();
}
template <typename T>
inline std::string vec2str(std::vector<T> const& vec, char const* delim = kDefaultDelimiter)
{
return arr2str(vec.data(), vec.size(), delim);
}
inline bool strStartsWith(std::string const& str, std::string const& prefix)
{
return str.rfind(prefix, 0) == 0;
}
/// @brief Split a string into a set of strings using a delimiter
std::unordered_set<std::string> str2set(std::string const& input, char delimiter);
} // namespace tensorrt_llm::common
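A host-only sketch of the string helpers above; the values are made up for illustration.
#include <cassert>
#include <string>
#include <vector>
#include "tensorrt_llm/common/stringUtils.h"

int main()
{
    using namespace tensorrt_llm::common;
    std::string const msg = fmtstr("rank %d processed %d tokens", 3, 128);
    assert(strStartsWith(msg, "rank 3"));
    std::vector<int> const shape{2, 4, 8};
    assert(vec2str(shape) == "(2, 4, 8)");
    auto const levels = str2set("INFO,DEBUG,TRACE", ',');
    assert(levels.size() == 3 && levels.count("DEBUG") == 1);
    return 0;
}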