Unverified commit 222ce6f1 authored by Yineng Zhang, committed by GitHub

add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)


Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
parent 468d23cf
sgl-kernel/3rdparty/tensorrt_llm/*
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)
add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
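# Consumption note (presumed usage): common_src is an OBJECT library, so a parent target is
# expected to absorb the compiled objects (for example via $<TARGET_OBJECTS:common_src> or by
# linking the object library directly); POSITION_INDEPENDENT_CODE and CUDA_RESOLVE_DEVICE_SYMBOLS
# are set so the objects can be folded into a shared library that also contains device code.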
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
namespace
{
bool initCheckDebug()
{
auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE";
auto const debugEnabled = std::getenv(kDebugEnabled);
return debugEnabled && debugEnabled[0] == '1';
}
} // namespace
bool DebugConfig::isCheckDebugEnabled()
{
static bool const debugEnabled = initCheckDebug();
return debugEnabled;
}
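// A minimal usage sketch, assuming the DebugConfig declaration from "tensorrt_llm/common/assert.h"
// above; maybeRunDebugChecks() and the branch body are hypothetical. Because initCheckDebug() runs
// once and the result is cached in a function-local static, the check is cheap on hot paths.
void maybeRunDebugChecks()
{
    if (DebugConfig::isCheckDebugEnabled()) // true only when TLLM_DEBUG_MODE=1 is exported
    {
        // Hypothetical: heavyweight validation that would be too slow for production runs.
    }
}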
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace tensorrt_llm
{
namespace common
{
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
: mCublasHandle(cublasHandle)
, mCublasLtHandle(cublasltHandle)
, mStream(stream)
, mCublasWorkspace(workspace)
{
}
CublasMMWrapper::~CublasMMWrapper() {}
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
: mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream)
{
}
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
{
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
}
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
{
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
}
void CublasMMWrapper::destroyDescriptors()
{
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
mOperationDesc = NULL;
mADesc = NULL;
mBDesc = NULL;
mCDesc = NULL;
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, (*heuristic).algo,
/* hasAlgo */ (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, (*heuristic).algo,
/* hasAlgo */ (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
{
bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ usingCublasLt);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
{
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
// TODO: default cublas libs
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
int batch_count = 1;
// fp32 use cublas as default
// fp16 use cublasLt as default
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (usingCublasLt)
{
if (hasAlgo)
{
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
}
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
sync_check_cuda_error();
}
else
{
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
// Go with the default heuristic to choose the tactic, as cuBLAS does not allow choosing tactics on Ampere and newer GPUs
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error();
}
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
float const f_beta)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::setWorkspace(void* workspace)
{
mCublasWorkspace = workspace;
}
void CublasMMWrapper::setFP32GemmConfig()
{
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
}
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
}
#ifdef ENABLE_BF16
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
}
#endif
#ifdef ENABLE_FP8
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
}
#endif
void CublasMMWrapper::setGemmConfig(
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
mAType = aType;
mBType = bType;
mCType = cType;
bool isFp16ComputeType = computeType == CUDA_R_16F;
if (isFp16ComputeType)
{
mComputeType = CUBLAS_COMPUTE_16F;
mScaleType = CUDA_R_16F;
}
else
{
mComputeType = CUBLAS_COMPUTE_32F;
mScaleType = CUDA_R_32F;
}
}
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
if (data_type == CUDA_R_16F)
{
return HALF_DATATYPE;
}
else if (data_type == CUDA_R_32F)
{
return FLOAT_DATATYPE;
}
else if (data_type == CUDA_R_8I)
{
return INT8_DATATYPE;
}
#ifdef ENABLE_BF16
else if (data_type == CUDA_R_16BF)
{
return BFLOAT16_DATATYPE;
}
#endif
return FLOAT_DATATYPE;
}
void CublasMMWrapper::setStream(cudaStream_t stream)
{
mStream = stream;
}
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
{
return false;
}
sync_check_cuda_error();
return true;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
sync_check_cuda_error();
return heuristics;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
return {};
#else
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
// Restrict reduction algorithms for numerical stability and better determinism
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
uint32_t pointer_mode_mask = 0;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
int return_count = 0;
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
heuristics.size(), heuristics.data(), &return_count));
heuristics.resize(return_count);
return heuristics;
#endif
}
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <optional>
#include <string>
namespace tensorrt_llm
{
namespace common
{
class CublasMMWrapper
{
protected:
std::shared_ptr<cublasHandle_t> mCublasHandle;
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
cudaDataType_t mAType{};
cudaDataType_t mBType{};
cudaDataType_t mCType{};
cublasComputeType_t mComputeType{};
cudaDataType_t mScaleType{};
cublasLtMatmulDesc_t mOperationDesc{NULL};
cublasLtMatrixLayout_t mADesc{NULL};
cublasLtMatrixLayout_t mBDesc{NULL};
cublasLtMatrixLayout_t mCDesc{NULL};
cudaStream_t mStream;
void* mCublasWorkspace = nullptr;
private:
bool descriptorsCreated() const
{
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
}
public:
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
cudaStream_t stream, void* workspace);
~CublasMMWrapper();
CublasMMWrapper(CublasMMWrapper const& wrapper);
/********************** GEMMs **********************/
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
float const f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
/********************** Tactic selection helpers **********************/
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
/********************** Utils **********************/
void setWorkspace(void* workspace);
void setFP32GemmConfig();
void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#ifdef ENABLE_BF16
void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
#endif
#ifdef ENABLE_FP8
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#endif
void setStream(cudaStream_t stream);
void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);
CublasDataType getCublasDataType(cudaDataType_t data_type);
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
void setScaleDescriptors(void* scale_a, void* scale_b);
void destroyDescriptors();
cublasHandle_t getCublasHandle()
{
return *(this->mCublasHandle);
}
cublasLtHandle_t getCublasLtHandle() const
{
return *(this->mCublasLtHandle);
}
};
} // namespace common
} // namespace tensorrt_llm
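// A minimal host-side usage sketch, assuming the wrapper header above is on the include path, that
// dA/dB/dC are device buffers in cuBLAS column-major layout, and that CUBLAS_WORKSPACE_SIZE is the
// workspace constant from the common headers; fp16GemmExample is hypothetical, and cleanup and
// error handling are omitted for brevity.
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include <memory>

void fp16GemmExample(half const* dA, half const* dB, half* dC, int m, int n, int k, cudaStream_t stream)
{
    auto cublasHandle = std::make_shared<cublasHandle_t>();
    auto cublasLtHandle = std::make_shared<cublasLtHandle_t>();
    cublasCreate(cublasHandle.get());
    cublasLtCreate(cublasLtHandle.get());
    void* workspace = nullptr;
    cudaMalloc(&workspace, CUBLAS_WORKSPACE_SIZE);

    tensorrt_llm::common::CublasMMWrapper gemm(cublasHandle, cublasLtHandle, stream, workspace);
    gemm.setFP16GemmConfig();                                           // fp16 inputs, fp32 accumulation
    gemm.createDescriptors(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, m, k, m); // A: m x k, B: k x n, C: m x n
    gemm.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, dA, m, dB, k, dC, m);  // C = A * B via the cublasLt path
    gemm.destroyDescriptors();
}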
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// We intentionally do not include cublas_api.h here. Although it defines the CUBLAS_VER_* macros,
// the version alone does not tell us whether the including file needs cublas.h, cublas_v2.h or
// cublasLt.h, so the caller is expected to include the appropriate cuBLAS header first.
#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH)
#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
<= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
< TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
>= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
> TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
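// A minimal usage sketch, assuming the translation unit has already included a cuBLAS header that
// defines CUBLAS_VER_MAJOR/MINOR/PATCH (cublas_v2.h or cublasLt.h). The comparisons are plain
// preprocessor arithmetic, so they can also gate whole code paths, as getTactics() above does with
// TLLM_CUBLAS_VER_LE(11, 4, 2); the constant below is illustrative.
#include <cublas_v2.h>

#if TLLM_CUBLAS_VER_GE(12, 0, 0)
constexpr bool kCublas12OrNewer = true;
#else
constexpr bool kCublas12OrNewer = false;
#endif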
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
#endif
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace common
} // namespace tensorrt_llm
// Operator definitions intentionally kept outside the tensorrt_llm namespaces (file-scope unnamed namespace below)
namespace
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hadd2(x, y);
};
#endif
#endif
} // namespace
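// A minimal device-side sketch, assuming an ENABLE_BF16 build; axpyBf16x2 and its parameters are
// hypothetical. bf16hfma2 dispatches to __hfma2 on SM80+ and to the float round-trip fallback on
// older architectures, so the same kernel source builds for both.
#ifdef ENABLE_BF16
__global__ void axpyBf16x2(
    __nv_bfloat162 const* x, __nv_bfloat162 const* y, __nv_bfloat162* out, __nv_bfloat162 alpha, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        // out[i] = alpha * x[i] + y[i], two bf16 lanes per element.
        out[i] = tensorrt_llm::common::bf16hfma2(alpha, x[i], y[i]);
    }
}
#endif // ENABLE_BF16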
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
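// A minimal usage sketch, assuming cubinImage points at a valid cubin and kernelName exists in it
// (both hypothetical). TLLM_CU_CHECK turns any non-zero CUresult into a TllmException carrying the
// driver's error name and message, and the wrapper loads the driver library via dlopen/LoadLibrary
// instead of linking libcuda directly.
#include "tensorrt_llm/common/cudaDriverWrapper.h"

CUfunction loadKernelFromCubin(void const* cubinImage, char const* kernelName)
{
    auto const driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module = nullptr;
    CUfunction function = nullptr;
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubinImage));
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&function, module, kernelName));
    return function;
}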
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <algorithm>
#include <cstdio>
#include <cuda_fp16.h>
#include <limits>
#include <type_traits>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_FP8
constexpr int CTA_SIZE = 256;
template <bool QUANTIZE>
__inline__ __device__ float scale(float a, float b)
{
return QUANTIZE ? a / b : a * b;
}
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
{
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
{
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
{
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
dst[tid] = (T_OUT) (static_cast<float>(tmp));
}
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
{
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
sync_check_cuda_error();
}
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, half, float>(
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
__device__ float atomicMaxExtd(float* address, float val)
{
assert(val >= 0);
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
return __uint_as_float(old);
}
template <typename T>
inline __device__ T atomicMaxExtdV2(T* address, T val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
static_assert(std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
// The address in 64 bits.
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
// Pack the input value into 32 bits.
union
{
T v[2];
uint16_t u[2];
} old, tmp = {};
int const loc = (address_u64 & 0x2) >> 1;
tmp.v[loc] = val;
// 4B aligned pointer.
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
if constexpr (std::is_same_v<T, half>)
{
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
// Return the correct half.
return old.v[loc];
#endif
}
__device__ half atomicMaxExtd(half* address, half val)
{
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_half(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
}
return __ushort_as_half(old);
}
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_bfloat16(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
}
return __ushort_as_bfloat16(old);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
{
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
{
float max = 0.f;
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (std::is_same_v<T_S, float>)
{
atomicMaxExtd(quant_ptr + col, scale);
}
else
{
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
atomicMaxExtd(quant_ptr + col, scale);
else
atomicMaxExtdV2(quant_ptr + col, scale);
}
#else // Vector atomics require __CUDA_ARCH__ >= 900
atomicMaxExtd(quant_ptr + col, scale);
#endif
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
auto const nrows = size / n;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
{
auto val = fabs(static_cast<float>(weights[row * n + i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
quant_ptr[row] = scale;
}
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
float max = 0.f;
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
atomicMaxExtd(quant_ptr, scale);
}
}
}
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 block(CTA_SIZE);
dim3 grid(numel / lda);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
#ifdef ENABLE_BF16
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
#endif
template <typename T_OUT, typename T_S, typename T_IN>
__global__ void dynamicQuantizeMatrixPerToken(
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
{
extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
auto const nrows = numel / lda;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
auto const in = input[row * lda + i];
shmem[i] = in;
auto val = fabs(static_cast<float>(in));
max = max > val ? max : val;
}
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
// true means we are quantizing
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
}
if (threadIdx.x == 0)
{
quant_ptr[row] = s;
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 grid(numel / lda);
bool use_shmem = true;
auto const shmem_size = lda * sizeof(T_IN);
if (shmem_size >= (48 << 10))
{
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
use_shmem = ret == cudaSuccess;
}
if (use_shmem)
{
// ensure the threadblock is as large as possible to increase occupancy
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
}
else
{
dim3 block(CTA_SIZE);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
sync_check_cuda_error();
}
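// Illustrative usage sketch, not part of the original sources (buffer names are
// hypothetical device pointers): quantize a [num_rows x lda] half matrix to FP8
// with one scale per row.
//   invokeComputeScalesAndQuantizeMatrix<__nv_fp8_e4m3, half, half>(
//       d_out, d_scales, d_in, num_rows * lda, lda, QuantizeMode::PER_TOKEN, stream);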
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream);
#ifdef ENABLE_FP8
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdint>
#include <optional>
namespace
{
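// Parses a comma-separated list of iteration indices or "start-end" ranges from the
// given environment variable. For example, a value of "5,10-20" yields the start
// iterations {5, 10} and the stop iterations {5, 20}.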
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
std::string const& envVarName)
{
auto envVarVal = std::getenv(envVarName.c_str());
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
std::unordered_set<int32_t> startSet;
std::unordered_set<int32_t> endSet;
for (std::string const& value : values)
{
size_t dashIdx = value.find("-");
if (dashIdx != std::string::npos)
{
int32_t start = std::stoi(value.substr(0, dashIdx));
startSet.insert(start);
int32_t end = std::stoi(value.substr(dashIdx + 1));
endSet.insert(end);
}
else
{
int32_t start_end = std::stoi(value);
startSet.insert(start_end);
endSet.insert(start_end);
}
}
return std::make_pair(startSet, endSet);
}
} // namespace
namespace tensorrt_llm::common
{
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
{
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
// If empty, try to use legacy env var name
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
{
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
{
TLLM_LOG_WARNING(
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
"Please use %s instead.",
legacyEnvVarName.value().c_str(), envVarName.c_str());
}
}
return std::make_pair(profileIterIdxs, stopIterIdxs);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <assert.h>
#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>
#endif
namespace tensorrt_llm
{
namespace common
{
template <typename T>
inline __device__ T ldg(T const* val)
{
return __ldg(val);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
// Get type2 from type or vice versa (applies to half and bfloat16)
template <typename T>
struct TypeConverter
{
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2>
{
using Type = half;
};
template <>
struct TypeConverter<half>
{
using Type = half2;
};
#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162>
{
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16>
{
using Type = __nv_bfloat162;
};
#endif // ENABLE_BF16
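// Example: TypeConverter<half>::Type is half2 and TypeConverter<half2>::Type is half,
// so kernels templated on the packed type can recover the scalar type (and vice versa).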
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
template <typename T>
inline __device__ T hadd2(T a, T b)
{
return __hadd2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T add(T a, T b)
{
return a + b;
}
template <>
inline __device__ half2 add(half2 a, half2 b)
{
return __hadd2(a, b);
}
template <>
inline __device__ half add(half a, half b)
{
return __hadd(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return bf16hadd(a, b);
}
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
return bf16hadd(a, __float2bfloat16(b));
}
#endif // ENABLE_BF16
// applies addition to all 3 values
template <typename T>
inline __device__ T add(T a, T b, T c)
{
return a + b + c;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hadd(a, b, c);
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hadd2(a, b, c);
}
#endif // ENABLE_BF16
// applies addition to all 4 values
template <typename T>
inline __device__ T add(T a, T b, T c, T d)
{
return (T) ((float) a + (float) b + (float) c + (float) d);
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
return bf16hadd(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hsub2(T a, T b)
{
return __hsub2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hsub2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b)
{
return __hmul2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T mul(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hmul(a, b, c);
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c, T d)
{
return a * b * c + d;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
return bf16hfma2(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c)
{
return a * b + c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
template <>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hfma(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hexp2(T a)
{
return h2exp(a);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
return bf16exp2(a);
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return bf162bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
#endif
// Binary maximum: compute the max of two values.
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_max(float2 val1, float2 val2)
{
float2 out;
out.x = fmaxf(val1.x, val2.x);
out.y = fmaxf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_max(half2 val1, half2 val2)
{
return __hmax2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmax2(val1, val2);
}
#endif // ENABLE_BF16
// Binary minimum: compute the min of two values.
template <typename T>
__device__ inline T cuda_min(T val1, T val2)
{
return (val1 < val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_min(float2 val1, float2 val2)
{
float2 out;
out.x = fminf(val1.x, val2.x);
out.y = fminf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_min(half2 val1, half2 val2)
{
return __hmin2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmin2(val1, val2);
}
#endif // ENABLE_BF16
// Helper function for clamping a value into the given range.
template <typename T>
inline __device__ T cuda_clamp(T val, T minVal, T maxVal)
{
return cuda_min(cuda_max(val, minVal), maxVal);
}
#ifdef ENABLE_FP8
template <>
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
template <>
__device__ inline half2 cuda_cast<half2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_half2(&val);
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
return (float) val;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_bfloat2(&val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
// not implemented; returns 0
return 0;
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
}
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
namespace tensorrt_llm::utils::customAllReduceUtils
{
constexpr size_t NUM_POINTERS_PER_RANK = 7;
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
if (worldSize <= 2)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
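// For example, getMaxRequiredWorkspaceSize(2) returns 16'000'000 bytes, while any
// world size greater than 2 returns 8'000'000 bytes.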
} // namespace tensorrt_llm::utils::customAllReduceUtils
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "envUtils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>
namespace tensorrt_llm::common
{
std::optional<int32_t> getIntEnv(char const* name)
{
char const* const env = std::getenv(name);
if (env == nullptr)
{
return std::nullopt;
}
int32_t const val = std::stoi(env);
if (val <= 0)
{
return std::nullopt;
}
return {val};
};
// Returns true if the env variable exists and is set to "1"
static bool getBoolEnv(char const* name)
{
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
return forceXQA;
}
std::optional<bool> getEnvEnableXQAJIT()
{
static bool init = false;
static bool exists = false;
static bool enableXQAJIT = false;
if (!init)
{
init = true;
char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
if (enable_xqa_jit_var)
{
exists = true;
if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
{
enableXQAJIT = true;
}
}
}
if (exists)
{
return enableXQAJIT;
}
else
{
return std::nullopt;
}
}
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug()
{
static bool init = false;
static bool forceMmhaMaxSeqLenTile = false;
if (!init)
{
init = true;
char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var)
{
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
{
forceMmhaMaxSeqLenTile = true;
}
}
}
return forceMmhaMaxSeqLenTile;
}
int getEnvMmhaBlocksPerSequence()
{
static bool init = false;
static int mmhaBlocksPerSequence = 0;
if (!init)
{
init = true;
char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv)
{
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
if (mmhaBlocksPerSequence <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
}
}
}
return mmhaBlocksPerSequence;
}
int getEnvMmhaKernelBlockSize()
{
static bool init = false;
static int mmhaKernelBlockSize = 0;
if (!init)
{
init = true;
char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
if (mmhaKernelBlockSizeEnv)
{
mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
if (mmhaKernelBlockSize <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
}
}
}
return mmhaKernelBlockSize;
}
bool getEnvEnablePDL()
{
static bool init = false;
static bool enablePDL = false;
if (!init)
{
init = true;
// PDL only available when arch >= 90
if (getSMVersion() >= 90)
{
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
}
}
return enablePDL;
}
bool getEnvUseUCXKvCache()
{
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
return useUCXKVCache;
}
std::string getEnvUCXInterface()
{
static bool init = false;
static std::string ucxInterface;
if (!init)
{
init = true;
{
char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE");
if (ucx_interface)
{
ucxInterface = ucx_interface;
}
}
}
return ucxInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
return disaggLayerwise;
}
bool getEnvParallelCacheSend()
{
static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
return parallelCacheSend;
}
bool getEnvRequestKVCacheSerial()
{
static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL");
return requestKVCacheSerial;
}
bool getEnvDisableKVCacheTransferOverlap()
{
static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP");
return disableKVCacheTransferOverlap;
}
bool getEnvDisableReceiveKVCacheParallel()
{
static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL");
return disableReceiveParallel;
}
} // namespace tensorrt_llm::common
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm::common
{
// Useful when you want to inject some debug code controlled by an env var.
std::optional<int32_t> getIntEnv(char const* name);
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Whether XQA JIT is enabled.
//
// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned.
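// For example, TRTLLM_ENABLE_XQA_JIT=1 yields true, any other set value yields false,
// and an unset variable yields std::nullopt.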
std::optional<bool> getEnvEnableXQAJIT();
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
int getEnvMmhaKernelBlockSize();
// Whether PDL is enabled.
bool getEnvEnablePDL();
bool getEnvUseUCXKvCache();
std::string getEnvUCXInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
bool getEnvRequestKVCacheSerial();
bool getEnvDisableKVCacheTransferOverlap();
bool getEnvDisableReceiveKVCacheParallel();
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <cuda_runtime.h>
namespace tensorrt_llm::common
{
Logger::Logger()
{
char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY");
bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON");
auto const* levelName = std::getenv("TLLM_LOG_LEVEL");
if (levelName != nullptr)
{
auto level = [levelName = std::string(levelName)]()
{
if (levelName == "TRACE")
return TRACE;
if (levelName == "DEBUG")
return DEBUG;
if (levelName == "INFO")
return INFO;
if (levelName == "WARNING")
return WARNING;
if (levelName == "ERROR")
return ERROR;
TLLM_THROW("Invalid log level: %s", levelName.c_str());
}();
// If TLLM_LOG_FIRST_RANK_ONLY=ON, set the log level of all other devices to ERROR
if (isFirstRankOnly)
{
auto const deviceId = getDevice();
if (deviceId != 1)
{
level = ERROR;
}
}
setLevel(level);
}
}
void Logger::log(std::exception const& ex, Logger::Level level)
{
log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what());
}
Logger* Logger::getLogger()
{
thread_local Logger instance;
return &instance;
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace common
{
////////////////////////////////////////////////////////////////////////////////////////////////////
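// Ceiling division, e.g. divUp(10, 4) == 3. Typically used to size CUDA grids, e.g.
// dim3 grid(divUp(numel, blockSize)) for a hypothetical element count and block size.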
template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cassert>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true);
template <typename T>
void deviceMemSetZero(T* ptr, size_t size);
template <typename T>
void deviceFree(T*& ptr);
template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
template <typename T>
void cudaD2Hcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaH2Dcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaRandomUniform(T* buffer, size_t const size);
template <typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename,
TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
// T* scale_ptr,
// std::vector<size_t> shape,
// std::string filename,
// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream);
#ifdef ENABLE_FP8
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functions implement conversion of multi-dimensional indices to an index in a flat array.
// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments.
// For examples on how to use these functions, see their tests in `test_memory_utils.cu`.
// All of these functions can be evaluated at compile time by recursive template expansion.
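// Illustrative example (in the spirit of those tests): with dims = {4, 3, 2},
// flat_index(dims, 1, 2, 1) == (1 * 3 + 2) * 2 + 1 == 11, i.e. a row-major flattening
// where the leading dimension only bounds the first index.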
template <typename TDim, typename T, typename TIndex>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index)
{
assert(index < dims[0]);
return acc * dims[0] + index;
}
template <typename TDim, typename T, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(acc * dims[0] + index, dims + 1, indices...);
}
template <typename TDim, typename T>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
[[maybe_unused]] TDim dims, T const& index)
{
assert(index < dims[0]);
return index;
}
template <typename TDim, typename TIndex, typename... TIndices>
__inline__ __host__ __device__
std::enable_if_t<std::is_pointer<TDim>::value, typename std::remove_pointer<TDim>::type> constexpr flat_index(
TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(static_cast<typename std::remove_pointer<TDim>::type>(index), dims + 1, indices...);
}
template <unsigned skip = 0, typename T, std::size_t N, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(&dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, &dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(static_cast<T const*>(dims) + skip, index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, static_cast<T const*>(dims) + skip, index, indices...);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual
// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions
// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be
// evaluated at compile time.
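// For example, flat_index3(i0, i1, i2, dim_1, dim_2) computes (i0 * dim_1 + i1) * dim_2 + i2.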
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1)
{
assert(index_1 < dim_1);
return index_0 * dim_1 + index_1;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2)
{
assert(index_2 < dim_2);
return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3)
{
assert(index_3 < dim_3);
return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3,
T const& dim_4)
{
assert(index_4 < dim_4);
return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2);
return index_0 * stride_1 + index_1 * stride_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2 / stride_3);
assert(index_3 < stride_3);
return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1);
template <typename T>
void invokeInPlaceTranspose0213(
T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3);
template <typename T>
void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2);
template <typename T>
void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T>
void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0);
inline bool checkIfFileExist(std::string const& file_path)
{
std::ifstream in(file_path, std::ios::in | std::ios::binary);
if (in.is_open())
{
in.close();
return true;
}
return false;
}
template <typename T>
void saveToBinary(T const* ptr, size_t const size, std::string filename);
template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream);
size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T>
bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream);
constexpr size_t DEFAULT_ALIGN_BYTES = 256;
size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
void calcAlignedPointers(std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes,
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
struct AlignedPointersUnpacker
{
template <typename... T>
void operator()(T*&... outPtrs)
{
assert(sizeof...(T) == alignedPointers.size());
auto it = alignedPointers.begin();
((outPtrs = static_cast<T*>(*it++)), ...);
}
std::vector<void*> alignedPointers;
};
AlignedPointersUnpacker inline calcAlignedPointers(
void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES)
{
AlignedPointersUnpacker unpacker{};
calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES);
return unpacker;
}
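// Illustrative usage sketch (buffer names and sizes are hypothetical): carve one
// workspace allocation into consecutive sub-buffers aligned to DEFAULT_ALIGN_BYTES.
//   float* a{}; half* b{}; int* c{};
//   calcAlignedPointers(workspace, {bytesA, bytesB, bytesC})(a, b, c);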
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <unordered_set>
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <csignal>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <type_traits>
#ifndef _WIN32
#include <unistd.h>
#endif
// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we pass a void pointer to some functions. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);
namespace tensorrt_llm::mpi
{
MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
{MpiType::kBYTE, MPI_BYTE},
{MpiType::kHALF, MPI_UINT16_T},
{MpiType::kFLOAT, MPI_FLOAT},
{MpiType::kDOUBLE, MPI_DOUBLE},
{MpiType::kBOOL, MPI_C_BOOL},
{MpiType::kINT8, MPI_INT8_T},
{MpiType::kUINT8, MPI_UINT8_T},
{MpiType::kINT32, MPI_INT32_T},
{MpiType::kUINT32, MPI_UINT32_T},
{MpiType::kINT64, MPI_INT64_T},
{MpiType::kUINT64, MPI_UINT64_T},
{MpiType::kFP8, MPI_UINT8_T},
{MpiType::kBF16, MPI_UINT16_T},
{MpiType::kCHAR, MPI_CHAR},
};
return dtype_map.at(dtype);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
}
MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiOp, MPI_Op> const op_map{
{MpiOp::NULLOP, MPI_OP_NULL},
{MpiOp::MAX, MPI_MAX},
{MpiOp::MIN, MPI_MIN},
{MpiOp::SUM, MPI_SUM},
{MpiOp::PROD, MPI_PROD},
{MpiOp::LAND, MPI_LAND},
{MpiOp::BAND, MPI_BAND},
{MpiOp::LOR, MPI_LOR},
{MpiOp::BOR, MPI_BOR},
{MpiOp::LXOR, MPI_LXOR},
{MpiOp::BXOR, MPI_BXOR},
{MpiOp::MINLOC, MPI_MINLOC},
{MpiOp::MAXLOC, MPI_MAXLOC},
{MpiOp::REPLACE, MPI_REPLACE},
};
return op_map.at(op);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
bool mpiInitialized = false;
std::recursive_mutex mpiMutex;
MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
MPI_Comm localComm = nullptr;
MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
MpiComm localSession{localComm, false};
#else
MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
return localSession;
}
} // namespace
std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
MPI_Group group = nullptr;
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPICHECK(MPI_Comm_group(comm, &group));
int groupSize = 0;
MPICHECK(MPI_Group_size(group, &groupSize));
std::vector<int> ranks(groupSize);
std::vector<int> worldRanks(groupSize);
std::iota(ranks.begin(), ranks.end(), 0);
MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
MPICHECK(MPI_Group_free(&group));
MPICHECK(MPI_Group_free(&worldGroup));
#else
std::vector<int> worldRanks{0};
#endif
return worldRanks;
}
void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
// double-checked locking
if (mpiInitialized)
{
return;
}
std::lock_guard<std::recursive_mutex> lk(mpiMutex);
if (mpiInitialized)
{
return;
}
#if ENABLE_MULTI_DEVICE
int initialized = 0;
TLLM_MPI_CHECK(MPI_Initialized(&initialized));
if (!initialized)
{
TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
int providedMode = 0;
auto requiredMode = static_cast<int>(threadMode);
MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
std::atexit([]() { MPI_Finalize(); });
/*
* We only catch SIGABRT and SIGSEGV because most, if not all, errors in the worker will cause one of these 2
* signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
* correctly.
*/
for (int sig : {SIGABRT, SIGSEGV})
{
__sighandler_t previousHandler = nullptr;
if (forwardAbortToParent)
{
previousHandler = std::signal(sig,
[](int signal)
{
#ifndef _WIN32
pid_t parentProcessId = getppid();
kill(parentProcessId, SIGKILL);
#endif
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
});
}
else
{
previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
}
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
}
// ensure local MPI communicator is initialized
MpiComm::localSession();
TLLM_LOG_INFO("Initialized MPI");
}
#endif // ENABLE_MULTI_DEVICE
mpiInitialized = true;
}
void MpiComm::barrier() const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Barrier(mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
#if ENABLE_MULTI_DEVICE
template <typename TMpiFunc, typename TBase, typename... TArgs,
typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
if (TLLM_LIKELY(size < maxP1))
{
MPICHECK(func(buffer, size, dtype, args...));
return 1;
}
constexpr size_t alignment = 256;
int elementSize = 1;
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);
// We cap at max alignment-bytes chunks that can be sent at once.
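// (Presumably, subtracting alignment / elementSize keeps the byte advance of each full
// chunk a multiple of the alignment, so intermediate chunk pointers stay aligned.)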
auto const step = maxP1 - (alignment / elementSize);
using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
size_t count = 0;
while (size != 0)
{
auto currentStep = static_cast<int>(std::min(size, step));
MPICHECK(func(buffer, currentStep, dtype, args...));
size -= currentStep;
size_t diff = static_cast<size_t>(currentStep) * elementSize;
buffer = static_cast<TCast*>(buffer) + diff;
++count;
}
return count;
}
#endif // ENABLE_MULTI_DEVICE
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return r;
}
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Isend with size %d", size);
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
TLLM_LOG_DEBUG("end MPI_Isend with size %d", size);
return r;
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
{
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Send with size %d", size);
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Send with size %d", size);
}
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Recv with size %d", size);
MPI_Status status{};
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Recv with size %d", size);
return status;
}
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
MpiComm MpiComm::split(int color, int key) const
{
MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return MpiComm{splitComm, true};
}
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
getMpiDtype(recvtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
void MpiComm::recvPoll(int source, int tag, int periodMs) const
{
MPI_Status status;
while (!iprobe(source, tag, &status))
{
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
}
}
int MpiComm::getRank() const
{
int rank = 0;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_rank(mComm, &rank));
#endif
return rank;
}
int MpiComm::getSize() const
{
int world_size = 1;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_size(mComm, &world_size));
#endif
return world_size;
}
MpiComm const& MpiComm::world()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commWorld{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commWorld;
}
MpiComm& MpiComm::mutableSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commSession{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commSession;
}
MpiComm& MpiComm::mutableLocalSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm localSession = initLocalSession();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return localSession;
}
void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
static std::mutex mutex;
std::unique_lock lock(mutex);
auto initSessionRanks = getWorldRanks(MpiComm::session());
auto localSessionRanks = getWorldRanks(MpiComm::localSession());
// Add to intersectionRanks in order of initSessionRanks
std::vector<int> intersectionRanks;
std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
for (auto rank : initSessionRanks)
{
if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
{
intersectionRanks.push_back(rank);
}
}
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPI_Group localGroup = nullptr;
MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
MPI_Comm localComm = nullptr;
MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
MpiComm::mutableLocalSession().mFreeComm = true;
MpiComm::mutableLocalSession() = MpiComm{localComm, false};
TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
: mComm{g}
, mFreeComm{freeComm}
{
TLLM_CHECK(mComm != MPI_COMM_NULL);
}
MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
if (mFreeComm && mComm)
{
if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
{
TLLM_LOG_ERROR("MPI_Comm_free failed");
}
}
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MpiComm&& comm) noexcept
: mComm{comm.mComm}
, mFreeComm{comm.mFreeComm}
{
comm.mFreeComm = false;
}
MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
this->~MpiComm();
mComm = comm.mComm;
mFreeComm = comm.mFreeComm;
comm.mFreeComm = false;
return *this;
}
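// Lifecycle sketch (inferred from the implementation below): the side thread publishes
// "idle" via notifyStop(), blocks in waitStart() until the owner calls notifyStart(),
// runs mFuncWait once, and loops. The owner's waitStop() blocks until the side thread
// is idle again; the destructor waits for idle, sets mShouldExit, and wakes the thread
// so its loop can terminate.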
MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
: mName{name.c_str()}
, mFuncWait{funcWait}
, mFuncSetup{funcSetup}
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
MpiWaitThread::~MpiWaitThread()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
waitStop();
mShouldExit.store(true);
notifyStart();
mThread->join();
mThread.reset(nullptr);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::sideThread()
{
if (mFuncSetup)
{
mFuncSetup();
}
while (!mShouldExit.load())
{
notifyStop();
waitStart();
mFuncWait();
}
}
void MpiWaitThread::waitStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::waitStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return !mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = true;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = false;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::mpi