Unverified Commit 9602c2aa authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

keep the parts needed for moe_kernels (#3218)

parent e81d7f11
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)
add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
File mode changed from 100755 to 100644
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <string>
namespace tensorrt_llm::common
{
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
}
} // namespace tensorrt_llm::common
class DebugConfig
{
public:
static bool isCheckDebugEnabled();
};
#if defined(_WIN32)
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
#else
#define TLLM_LIKELY(x) __builtin_expect((x), 1)
#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
#endif
#define TLLM_CHECK(val) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} while (0)
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} while (0)
#define TLLM_CHECK_DEBUG(val) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} \
} while (0)
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} \
} while (0)
#define TLLM_THROW(...) \
do \
{ \
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
} while (0)
#define TLLM_WRAP(ex) \
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <stdint.h>
#define FP8_MHA
#define FUSE_GEMM_ACT
#define FP8_GEMM_OUTPUT_QUANT_DISABLE
#ifdef FUSE_GEMM_ACT
#define USE_QGMMA
#endif
namespace tensorrt_llm
{
namespace common
{
constexpr float FP8_E4M3_MAX = 448.0f;
enum QuantizeMode
{
PER_CHANNEL,
PER_TENSOR,
PER_CHANNEL_WEIGHT_PER_TENSOR_ACT,
PER_TOKEN,
};
// Packed Data Type
typedef struct __CUDA_ALIGN__(32)
{
float array[8];
} float8;
typedef struct __CUDA_ALIGN__(16)
{
half array[8];
} half8;
typedef struct __CUDA_ALIGN__(8)
{
half2 array[2];
} half2_2;
typedef struct __CUDA_ALIGN__(8)
{
half array[4];
} half_4;
#ifdef ENABLE_BF16
typedef struct __CUDA_ALIGN__(4)
{
__nv_bfloat16 array[2];
} __nv_bfloat16_2;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat162 x, y;
} __nv_bfloat162_2_xy;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat16 array[4];
} __nv_bfloat164;
typedef struct __CUDA_ALIGN__(8)
{
__nv_bfloat162 array[2];
} __nv_bfloat162_2;
typedef struct __CUDA_ALIGN__(16)
{
__nv_bfloat16 array[8];
} __nv_bfloat168;
typedef struct __CUDA_ALIGN__(16)
{
__nv_bfloat162 array[4];
} __nv_bfloat162_4;
typedef struct __CUDA_ALIGN__(32)
{
__nv_bfloat16 array[16];
} __nv_bfloat1616;
#endif
#ifdef ENABLE_FP8
typedef struct __CUDA_ALIGN__(2)
{
__nv_fp8_e4m3 array[2];
} __nv_fp8_2_e4m3;
typedef struct __CUDA_ALIGN__(4)
{
__nv_fp8_e4m3 array[4];
} __nv_fp8_4_e4m3;
typedef struct __CUDA_ALIGN__(4)
{
__nv_fp8x2_e4m3 array[2];
} __nv_fp8x2_x2_e4m3;
typedef struct __CUDA_ALIGN__(8)
{
__nv_fp8_e4m3 array[8];
} __nv_fp8_8_e4m3;
typedef struct __CUDA_ALIGN__(8)
{
__nv_fp8x2_e4m3 array[4];
} __nv_fp8x2_x4_e4m3;
typedef struct __CUDA_ALIGN__(16)
{
__nv_fp8_e4m3 array[16];
} __nv_fp8x16_e4m3;
#endif
// only BF16 and FP8
template <typename T, int PACK_SIZE>
struct PackType
{
using type = float;
};
#ifdef ENABLE_BF16
template <>
struct PackType<__nv_bfloat16, 2>
{
using type = __nv_bfloat16_2;
};
template <>
struct PackType<__nv_bfloat16, 4>
{
using type = __nv_bfloat164;
};
template <>
struct PackType<__nv_bfloat16, 8>
{
using type = __nv_bfloat168;
};
#endif
#ifdef ENABLE_FP8
template <>
struct PackType<__nv_fp8_e4m3, 2>
{
using type = __nv_fp8_2_e4m3;
};
template <>
struct PackType<__nv_fp8_e4m3, 4>
{
using type = __nv_fp8_4_e4m3;
};
template <>
struct PackType<__nv_fp8_e4m3, 8>
{
using type = __nv_fp8_8_e4m3;
};
#endif
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
{
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
{
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out;
}
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
{
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}
__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
{
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out;
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream);
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream);
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
} // namespace common
} // namespace tensorrt_llm
#endif // ENABLE_FP8
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdint>
#include <optional>
namespace
{
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
std::string const& envVarName)
{
auto envVarVal = std::getenv(envVarName.c_str());
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
std::unordered_set<int32_t> startSet;
std::unordered_set<int32_t> endSet;
for (std::string const& value : values)
{
size_t dashIdx = value.find("-");
if (dashIdx != std::string::npos)
{
int32_t start = std::stoi(value.substr(0, dashIdx));
startSet.insert(start);
int32_t end = std::stoi(value.substr(dashIdx + 1));
endSet.insert(end);
}
else
{
int32_t start_end = std::stoi(value);
startSet.insert(start_end);
endSet.insert(start_end);
}
}
return std::make_pair(startSet, endSet);
}
} // namespace
namespace tensorrt_llm::common
{
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
{
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
// If empty, try to use legacy env var name
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
{
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
{
TLLM_LOG_WARNING(
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
"Please "
"use %s "
"instead.",
legacyEnvVarName.value().c_str(), envVarName.c_str());
}
}
return std::make_pair(profileIterIdxs, stopIterIdxs);
}
} // namespace tensorrt_llm::common
This diff is collapsed.
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
namespace tensorrt_llm::utils::customAllReduceUtils
{
constexpr size_t NUM_POINTERS_PER_RANK = 7;
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
if (worldSize <= 2)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
} // namespace tensorrt_llm::utils::customAllReduceUtils
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "envUtils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>
namespace tensorrt_llm::common
{
std::optional<int32_t> getIntEnv(char const* name)
{
char const* const env = std::getenv(name);
if (env == nullptr)
{
return std::nullopt;
}
int32_t const val = std::stoi(env);
if (val <= 0)
{
return std::nullopt;
}
return {val};
};
// Returns true if the env variable exists and is set to "1"
static bool getBoolEnv(char const* name)
{
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
return forceXQA;
}
std::optional<bool> getEnvEnableXQAJIT()
{
static bool init = false;
static bool exists = false;
static bool enableXQAJIT = false;
if (!init)
{
init = true;
char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
if (enable_xqa_jit_var)
{
exists = true;
if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
{
enableXQAJIT = true;
}
}
}
if (exists)
{
return enableXQAJIT;
}
else
{
return std::nullopt;
}
}
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{
static bool init = false;
static bool forceMmhaMaxSeqLenTile = false;
if (!init)
{
init = true;
char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var)
{
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
{
forceMmhaMaxSeqLenTile = true;
}
}
}
return forceMmhaMaxSeqLenTile;
}
int getEnvMmhaBlocksPerSequence()
{
static bool init = false;
static int mmhaBlocksPerSequence = 0;
if (!init)
{
init = true;
char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv)
{
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
if (mmhaBlocksPerSequence <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
}
}
}
return mmhaBlocksPerSequence;
}
int getEnvMmhaKernelBlockSize()
{
static bool init = false;
static int mmhaKernelBlockSize = 0;
if (!init)
{
init = true;
char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
if (mmhaKernelBlockSizeEnv)
{
mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
if (mmhaKernelBlockSize <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
}
}
}
return mmhaKernelBlockSize;
}
bool getEnvEnablePDL()
{
static bool init = false;
static bool enablePDL = false;
if (!init)
{
init = true;
// PDL only available when arch >= 90
if (getSMVersion() >= 90)
{
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
}
}
return enablePDL;
}
bool getEnvUseUCXKvCache()
{
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
return useUCXKVCache;
}
std::string getEnvUCXInterface()
{
static bool init = false;
static std::string ucxInterface;
if (!init)
{
init = true;
{
char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE");
if (ucx_interface)
{
ucxInterface = ucx_interface;
}
}
}
return ucxInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
return disaggLayerwise;
}
bool getEnvParallelCacheSend()
{
static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
return parallelCacheSend;
}
bool getEnvRequestKVCacheSerial()
{
static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL");
return requestKVCacheSerial;
}
bool getEnvDisableKVCacheTransferOverlap()
{
static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP");
return disableKVCacheTransferOverlap;
}
bool getEnvDisableReceiveKVCacheParallel()
{
static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL");
return disableReceiveParallel;
}
} // namespace tensorrt_llm::common
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm::common
{
// Useful when you want to inject some debug code controllable with env var.
std::optional<int32_t> getIntEnv(char const* name);
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Whether XQA JIT is enabled.
//
// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned.
std::optional<bool> getEnvEnableXQAJIT();
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
int getEnvMmhaKernelBlockSize();
// Whether PDL is enabled.
bool getEnvEnablePDL();
bool getEnvUseUCXKvCache();
std::string getEnvUCXInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
bool getEnvRequestKVCacheSerial();
bool getEnvDisableKVCacheTransferOverlap();
bool getEnvDisableReceiveKVCacheParallel();
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stringUtils.h"
namespace tensorrt_llm::common
{
class Logger
{
// On Windows, the file wingdi.h is included which has
// #define ERROR 0
// This breaks everywhere ERROR is used in the Level enum
#ifdef _WIN32
#undef ERROR
#endif // _WIN32
public:
enum Level
{
TRACE = 0,
DEBUG = 10,
INFO = 20,
WARNING = 30,
ERROR = 40
};
static Logger* getLogger();
Logger(Logger const&) = delete;
void operator=(Logger const&) = delete;
#if defined(_MSC_VER)
template <typename... Args>
void log(Level level, char const* format, Args const&... args);
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args);
#else
template <typename... Args>
void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
#endif
template <typename... Args>
void log(Level level, std::string const& format, Args const&... args)
{
return log(level, format.c_str(), args...);
}
template <typename... Args>
void log(Level const level, int const rank, std::string const& format, Args const&... args)
{
return log(level, rank, format.c_str(), args...);
}
void log(std::exception const& ex, Level level = Level::ERROR);
Level getLevel() const
{
return level_;
}
void setLevel(Level const level)
{
level_ = level;
log(INFO, "Set logger level to %s", getLevelName(level));
}
bool isEnabled(Level const level) const
{
return level_ <= level;
}
private:
static auto constexpr kPREFIX = "[TensorRT-LLM]";
#ifndef NDEBUG
Level const DEFAULT_LOG_LEVEL = DEBUG;
#else
Level const DEFAULT_LOG_LEVEL = INFO;
#endif
Level level_ = DEFAULT_LOG_LEVEL;
Logger(); // NOLINT(modernize-use-equals-delete)
static inline char const* getLevelName(Level const level)
{
switch (level)
{
case TRACE: return "TRACE";
case DEBUG: return "DEBUG";
case INFO: return "INFO";
case WARNING: return "WARNING";
case ERROR: return "ERROR";
}
TLLM_THROW("Unknown log level: %d", level);
}
static inline std::string getPrefix(Level const level)
{
return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
}
static inline std::string getPrefix(Level const level, int const rank)
{
return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
}
};
template <typename... Args>
void Logger::log(Logger::Level level, char const* format, Args const&... args)
{
if (isEnabled(level))
{
auto const fmt = getPrefix(level) + format;
auto& out = level_ < WARNING ? std::cout : std::cerr;
if constexpr (sizeof...(args) > 0)
{
out << fmtstr(fmt.c_str(), args...);
}
else
{
out << fmt;
}
out << std::endl;
}
}
template <typename... Args>
void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args)
{
if (isEnabled(level))
{
auto const fmt = getPrefix(level, rank) + format;
auto& out = level_ < WARNING ? std::cout : std::cerr;
if constexpr (sizeof...(args) > 0)
{
out << fmtstr(fmt.c_str(), args...);
}
else
{
out << fmt;
}
out << std::endl;
}
}
#define TLLM_LOG(level, ...) \
do \
{ \
auto* const logger = tensorrt_llm::common::Logger::getLogger(); \
if (logger->isEnabled(level)) \
{ \
logger->log(level, __VA_ARGS__); \
} \
} while (0)
#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__)
#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__)
#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__)
#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__)
#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__)
#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__)
} // namespace tensorrt_llm::common
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace common
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace common
} // namespace tensorrt_llm
This diff is collapsed.
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cassert>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true);
template <typename T>
void deviceMemSetZero(T* ptr, size_t size);
template <typename T>
void deviceFree(T*& ptr);
template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
template <typename T>
void cudaD2Hcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaH2Dcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaRandomUniform(T* buffer, size_t const size);
template <typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename,
TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
// T* scale_ptr,
// std::vector<size_t> shape,
// std::string filename,
// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream);
#ifdef ENABLE_FP8
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functions implement conversion of multi-dimensional indices to an index in a flat array.
// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments.
// For examples on how to use these functions, see their tests `test_memory_utils.cu`.
// All of these functions can be evaluated at compile time by recursive template expansion.
template <typename TDim, typename T, typename TIndex>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index)
{
assert(index < dims[0]);
return acc * dims[0] + index;
}
template <typename TDim, typename T, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(acc * dims[0] + index, dims + 1, indices...);
}
template <typename TDim, typename T>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
[[maybe_unused]] TDim dims, T const& index)
{
assert(index < dims[0]);
return index;
}
template <typename TDim, typename TIndex, typename... TIndices>
__inline__ __host__ __device__
std::enable_if_t<std::is_pointer<TDim>::value, typename std::remove_pointer<TDim>::type> constexpr flat_index(
TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(static_cast<typename std::remove_pointer<TDim>::type>(index), dims + 1, indices...);
}
template <unsigned skip = 0, typename T, std::size_t N, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(&dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, &dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(static_cast<T const*>(dims) + skip, index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, static_cast<T const*>(dims) + skip, index, indices...);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual
// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions
// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be
// evaluated at compile time.
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1)
{
assert(index_1 < dim_1);
return index_0 * dim_1 + index_1;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2)
{
assert(index_2 < dim_2);
return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3)
{
assert(index_3 < dim_3);
return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3,
T const& dim_4)
{
assert(index_4 < dim_4);
return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2);
return index_0 * stride_1 + index_1 * stride_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2 / stride_3);
assert(index_3 < stride_3);
return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1);
template <typename T>
void invokeInPlaceTranspose0213(
T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3);
template <typename T>
void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2);
template <typename T>
void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T>
void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0);
inline bool checkIfFileExist(std::string const& file_path)
{
std::ifstream in(file_path, std::ios::in | std::ios::binary);
if (in.is_open())
{
in.close();
return true;
}
return false;
}
template <typename T>
void saveToBinary(T const* ptr, size_t const size, std::string filename);
template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream);
size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T>
bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream);
constexpr size_t DEFAULT_ALIGN_BYTES = 256;
size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
void calcAlignedPointers(std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes,
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
struct AlignedPointersUnpacker
{
template <typename... T>
void operator()(T*&... outPtrs)
{
assert(sizeof...(T) == alignedPointers.size());
auto it = alignedPointers.begin();
((outPtrs = static_cast<T*>(*it++)), ...);
}
std::vector<void*> alignedPointers;
};
AlignedPointersUnpacker inline calcAlignedPointers(
void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES)
{
AlignedPointersUnpacker unpacker{};
calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES);
return unpacker;
}
} // namespace common
} // namespace tensorrt_llm
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <unordered_set>
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <csignal>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <type_traits>
#ifndef _WIN32
#include <unistd.h>
#endif
// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we're passing void ptr to some function. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);
namespace tensorrt_llm::mpi
{
MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
{MpiType::kBYTE, MPI_BYTE},
{MpiType::kHALF, MPI_UINT16_T},
{MpiType::kFLOAT, MPI_FLOAT},
{MpiType::kDOUBLE, MPI_DOUBLE},
{MpiType::kBOOL, MPI_C_BOOL},
{MpiType::kINT8, MPI_INT8_T},
{MpiType::kUINT8, MPI_UINT8_T},
{MpiType::kINT32, MPI_INT32_T},
{MpiType::kUINT32, MPI_UINT32_T},
{MpiType::kINT64, MPI_INT64_T},
{MpiType::kUINT64, MPI_UINT64_T},
{MpiType::kFP8, MPI_UINT8_T},
{MpiType::kBF16, MPI_UINT16_T},
{MpiType::kCHAR, MPI_CHAR},
};
return dtype_map.at(dtype);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
}
MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiOp, MPI_Op> const op_map{
{MpiOp::NULLOP, MPI_OP_NULL},
{MpiOp::MAX, MPI_MAX},
{MpiOp::MIN, MPI_MIN},
{MpiOp::SUM, MPI_SUM},
{MpiOp::PROD, MPI_PROD},
{MpiOp::LAND, MPI_LAND},
{MpiOp::BAND, MPI_BAND},
{MpiOp::LOR, MPI_LOR},
{MpiOp::BOR, MPI_BOR},
{MpiOp::LXOR, MPI_LXOR},
{MpiOp::BXOR, MPI_BXOR},
{MpiOp::MINLOC, MPI_MINLOC},
{MpiOp::MAXLOC, MPI_MAXLOC},
{MpiOp::REPLACE, MPI_REPLACE},
};
return op_map.at(op);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
bool mpiInitialized = false;
std::recursive_mutex mpiMutex;
MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
MPI_Comm localComm = nullptr;
MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
MpiComm localSession{localComm, false};
#else
MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
return localSession;
}
} // namespace
std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
MPI_Group group = nullptr;
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPICHECK(MPI_Comm_group(comm, &group));
int groupSize = 0;
MPICHECK(MPI_Group_size(group, &groupSize));
std::vector<int> ranks(groupSize);
std::vector<int> worldRanks(groupSize);
std::iota(ranks.begin(), ranks.end(), 0);
MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
MPICHECK(MPI_Group_free(&group));
MPICHECK(MPI_Group_free(&worldGroup));
#else
std::vector<int> worldRanks{0};
#endif
return worldRanks;
}
void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
// double-checked locking
if (mpiInitialized)
{
return;
}
std::lock_guard<std::recursive_mutex> lk(mpiMutex);
if (mpiInitialized)
{
return;
}
#if ENABLE_MULTI_DEVICE
int initialized = 0;
TLLM_MPI_CHECK(MPI_Initialized(&initialized));
if (!initialized)
{
TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
int providedMode = 0;
auto requiredMode = static_cast<int>(threadMode);
MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
std::atexit([]() { MPI_Finalize(); });
/*
* We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2
* signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
* correctly.
*/
for (int sig : {SIGABRT, SIGSEGV})
{
__sighandler_t previousHandler = nullptr;
if (forwardAbortToParent)
{
previousHandler = std::signal(sig,
[](int signal)
{
#ifndef _WIN32
pid_t parentProcessId = getppid();
kill(parentProcessId, SIGKILL);
#endif
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
});
}
else
{
previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
}
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
}
// ensure local MPI communicator is initialized
MpiComm::localSession();
TLLM_LOG_INFO("Initialized MPI");
}
#endif // ENABLE_MULTI_DEVICE
mpiInitialized = true;
}
void MpiComm::barrier() const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Barrier(mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
#if ENABLE_MULTI_DEVICE
template <typename TMpiFunc, typename TBase, typename... TArgs,
typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
if (TLLM_LIKELY(size < maxP1))
{
MPICHECK(func(buffer, size, dtype, args...));
return 1;
}
constexpr size_t alignment = 256;
int elementSize = 1;
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);
// We cap at max alignment-bytes chunks that can be sent at once.
auto const step = maxP1 - (alignment / elementSize);
using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
size_t count = 0;
while (size != 0)
{
auto currentStep = static_cast<int>(std::min(size, step));
MPICHECK(func(buffer, currentStep, dtype, args...));
size -= currentStep;
size_t diff = static_cast<size_t>(currentStep) * elementSize;
buffer = static_cast<TCast*>(buffer) + diff;
++count;
}
return count;
}
#endif // ENABLE_MULTI_DEVICE
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return r;
}
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Isend with size %d", size);
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
TLLM_LOG_DEBUG("end MPI_Isend with size %d", size);
return r;
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
{
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Send with size %d", size);
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Send with size %d", size);
}
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Recv with size %d", size);
MPI_Status status{};
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Recv with size %d", size);
return status;
}
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
MpiComm MpiComm::split(int color, int key) const
{
MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return MpiComm{splitComm, true};
}
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
getMpiDtype(recvtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
void MpiComm::recvPoll(int source, int tag, int periodMs) const
{
MPI_Status status;
while (!iprobe(source, tag, &status))
{
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
}
}
int MpiComm::getRank() const
{
int rank = 0;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_rank(mComm, &rank));
#endif
return rank;
}
int MpiComm::getSize() const
{
int world_size = 1;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_size(mComm, &world_size));
#endif
return world_size;
}
MpiComm const& MpiComm::world()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commWorld{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commWorld;
}
MpiComm& MpiComm::mutableSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commSession{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commSession;
}
MpiComm& MpiComm::mutableLocalSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm localSession = initLocalSession();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return localSession;
}
void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
static std::mutex mutex;
std::unique_lock lock(mutex);
auto initSessionRanks = getWorldRanks(MpiComm::session());
auto localSessionRanks = getWorldRanks(MpiComm::localSession());
// Add to intersectionRanks in order of initSessionRanks
std::vector<int> intersectionRanks;
std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
for (auto rank : initSessionRanks)
{
if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
{
intersectionRanks.push_back(rank);
}
}
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPI_Group localGroup = nullptr;
MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
MPI_Comm localComm = nullptr;
MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
MpiComm::mutableLocalSession().mFreeComm = true;
MpiComm::mutableLocalSession() = MpiComm{localComm, false};
TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
: mComm{g}
, mFreeComm{freeComm}
{
TLLM_CHECK(mComm != MPI_COMM_NULL);
}
MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
if (mFreeComm && mComm)
{
if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
{
TLLM_LOG_ERROR("MPI_Comm_free failed");
}
}
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MpiComm&& comm) noexcept
: mComm{comm.mComm}
, mFreeComm{comm.mFreeComm}
{
comm.mFreeComm = false;
}
MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
this->~MpiComm();
mComm = comm.mComm;
mFreeComm = comm.mFreeComm;
comm.mFreeComm = false;
return *this;
}
MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
: mName{name.c_str()}
, mFuncWait{funcWait}
, mFuncSetup{funcSetup}
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
MpiWaitThread::~MpiWaitThread()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
waitStop();
mShouldExit.store(true);
notifyStart();
mThread->join();
mThread.reset(nullptr);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::sideThread()
{
if (mFuncSetup)
{
mFuncSetup();
}
while (!mShouldExit.load())
{
notifyStop();
waitStart();
mFuncWait();
}
}
void MpiWaitThread::waitStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::waitStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return !mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = true;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = false;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::mpi
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvtx3/nvtx3.hpp>
#include <array>
namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#ifndef NVTX_DISABLE
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();
static thread_local std::size_t colorId = 0;
auto const color = kColors[colorId];
colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
return color;
#else
return nvtx3::color{0};
#endif
}
} // namespace tensorrt_llm::common::nvtx
#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \
::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name)
#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "cuda.h"
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <functional>
#include <mutex>
#include <thread>
#ifdef _MSC_VER
#define FN_NAME __FUNCTION__
#else
#define FN_NAME __func__
#endif
#if ENABLE_MULTI_DEVICE
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32},
{nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}};
return &dtypeMap;
}
namespace
{
// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin())
{
NCCLCHECK(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
COMM_SESSION.sendValue(id, *it, 0);
}
}
else
{
COMM_SESSION.recvValue(id, *group.begin(), 0);
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return id;
}
} // namespace
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group)
{
if (index != 0)
{
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end())
{
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}
TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group)
{
if (rank == currentRank)
break;
++groupRank;
}
TLLM_CHECK(groupRank < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
[](ncclComm_t* comm)
{
ncclCommDestroy(*comm);
delete comm;
});
NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
void const* tensorrt_llm::common::getCommSessionHandle()
{
#if ENABLE_MULTI_DEVICE
return &COMM_SESSION;
#else
return nullptr;
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
// Get current cuda context, a default context will be created if there is no context.
inline CUcontext getCurrentCudaCtx()
{
CUcontext ctx{};
CUresult err = cuCtxGetCurrent(&ctx);
if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
{
TLLM_CUDA_CHECK(cudaFree(nullptr));
err = cuCtxGetCurrent(&ctx);
}
TLLM_CHECK(err == CUDA_SUCCESS);
return ctx;
}
// Helper to create per-cuda-context singleton managed by std::shared_ptr.
// Unlike conventional singletons, singleton created with this will be released
// when not needed, instead of on process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
// creator returning std::unique_ptr is by design.
// It forces separation of memory for T and memory for control blocks.
// So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
// creator itself must not own CUDA resources. Only the object it creates can.
PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
CUcontext ctx{getCurrentCudaCtx()};
std::shared_ptr<T> result = mObservers[ctx].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, ctx](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
// frequently.
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
std::lock_guard<std::mutex> lk{mMutex};
// Must check observer again because another thread may created new instance for this ctx just
// before we lock mMutex. We can't infer that the observer is stale from the fact that obj is
// destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic
// operation, and the observer may be changed to observe another instance.
observedObjHolder = mObservers.at(ctx).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(ctx);
}
}};
mObservers.at(ctx) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-context.
std::unordered_map<CUcontext, std::weak_ptr<T>> mObservers;
};
template <typename T>
class PerThreadSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
// creator returning std::unique_ptr is by design.
// It forces separation of memory for T and memory for control blocks.
// So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
// creator itself must not own CUDA resources. Only the object it creates can.
PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
std::thread::id thread = std::this_thread::get_id();
std::shared_ptr<T> result = mObservers[thread].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, thread](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
// frequently.
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
std::lock_guard<std::mutex> lk{mMutex};
// Must check observer again because another thread may created new instance for this ctx just
// before we lock mMutex. We can't infer that the observer is stale from the fact that obj is
// destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic
// operation, and the observer may be changed to observe another instance.
observedObjHolder = mObservers.at(thread).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(thread);
}
}};
mObservers.at(thread) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-thread.
std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
};
} // namespace
std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
TLLM_CUDA_CHECK(cublasCreate(handle.get()));
return handle;
},
[](cublasHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
static PerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
return handle;
},
[](cublasLtHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
{
static PerThreadSingletonCreator<tensorrt_llm::common::CublasMMWrapper> creator(
[cublasHandle, cublasltHandle, stream, workspace]() -> auto
{
auto wrapper = std::unique_ptr<tensorrt_llm::common::CublasMMWrapper>(
new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace));
return wrapper;
},
[](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; });
return creator();
}
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/workspace.h"
#include <NvInferRuntime.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <cstring>
#include <map>
#include <memory>
#include <nvml.h>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
namespace tensorrt_llm::common
{
// Write values into buffer
template <typename T>
void write(char*& buffer, T const& val)
{
std::memcpy(buffer, &val, sizeof(T));
buffer += sizeof(T);
}
// Read values from buffer
template <typename T>
void read(char const*& buffer, T& val)
{
std::memcpy(&val, buffer, sizeof(T));
buffer += sizeof(T);
}
// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
// your clone() implementation is responsible for initializing such data members.
// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
using std::unique_ptr<T, Del>::unique_ptr;
// for compatibility with std::make_unique
explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
: std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
{
}
// copy constructor produces nullptr
UniqPtrWNullCopy(UniqPtrWNullCopy const&)
: std::unique_ptr<T, Del>::unique_ptr{}
{
}
};
// for testing only
void const* getCommSessionHandle();
} // namespace tensorrt_llm::common
inline bool isBuilding()
{
auto constexpr key = "IS_BUILDING";
auto const val = getenv(key);
return val != nullptr && std::string(val) == "1";
}
#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) \
{ \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
#endif // ENABLE_MULTI_DEVICE
//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally.
//! Get cublas and cublasLt handle for current cuda context
std::shared_ptr<cublasHandle_t> getCublasHandle();
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace);
#ifndef DEBUG
#define PLUGIN_CHECK(status) \
do \
{ \
if (status != 0) \
abort(); \
} while (0)
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
return STATUS_BAD_PARAM; \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
return STATUS_FAILURE; \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
return err; \
} \
} while (0)
#define DEBUG_PRINTF(...) \
do \
{ \
} while (0)
#else
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_BAD_PARAM; \
} \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_FAILURE; \
} \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \
return err; \
} \
} while (0)
#define PLUGIN_CHECK(status) \
{ \
if (status != 0) \
{ \
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \
abort(); \
} \
}
#define DEBUG_PRINTF(...) \
do \
{ \
printf(__VA_ARGS__); \
} while (0)
#endif // DEBUG
#define NVML_CHECK(cmd) \
do \
{ \
nvmlReturn_t r = cmd; \
if (r != NVML_SUCCESS) \
{ \
printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm
{
namespace common
{
class QuantMode
{
// [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py
public:
using BaseType = std::uint32_t;
explicit constexpr QuantMode(BaseType value) noexcept
: mValue{value}
{
}
QuantMode() noexcept = default;
constexpr QuantMode(QuantMode const&) noexcept = default;
constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;
static constexpr QuantMode none() noexcept
{
return QuantMode(BaseType(0));
}
static constexpr QuantMode int4Weights() noexcept
{
return QuantMode(BaseType(1u) << 0);
}
static constexpr QuantMode int8Weights() noexcept
{
return QuantMode(BaseType(1u) << 1);
}
static constexpr QuantMode activations() noexcept
{
return QuantMode(BaseType(1u) << 2);
}
static constexpr QuantMode perChannelScaling() noexcept
{
return QuantMode(BaseType(1u) << 3);
}
static constexpr QuantMode perTokenScaling() noexcept
{
return QuantMode(BaseType(1u) << 4);
}
static constexpr QuantMode perGroupScaling() noexcept
{
return QuantMode(BaseType(1u) << 5);
}
static constexpr QuantMode int8KvCache() noexcept
{
return QuantMode(BaseType(1u) << 6);
}
static constexpr QuantMode fp8KvCache() noexcept
{
return QuantMode(BaseType(1u) << 7);
}
static constexpr QuantMode fp8Qdq() noexcept
{
return QuantMode(BaseType(1u) << 8);
}
static constexpr QuantMode fp8RowWise() noexcept
{
return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
}
static constexpr QuantMode w4a8QServe() noexcept
{
return QuantMode(BaseType(1u) << 10);
}
constexpr BaseType value() const noexcept
{
return mValue;
}
constexpr bool isSet(QuantMode const& mode) const noexcept
{
return (mValue & mode.value()) == mode.value();
}
constexpr bool hasInt4Weights() const noexcept
{
return isSet(int4Weights());
}
constexpr bool hasInt8Weights() const noexcept
{
return isSet(int8Weights());
}
constexpr bool hasActivations() const noexcept
{
return isSet(activations());
}
constexpr bool hasPerChannelScaling() const noexcept
{
return isSet(perChannelScaling());
}
constexpr bool hasPerTokenScaling() const noexcept
{
return isSet(perTokenScaling());
}
constexpr bool hasPerGroupScaling() const noexcept
{
return isSet(perGroupScaling());
}
constexpr bool hasStaticActivationScaling() const noexcept
{
return !hasPerTokenScaling();
}
constexpr bool hasInt8KvCache() const noexcept
{
return isSet(int8KvCache());
}
constexpr bool hasFp8KvCache() const noexcept
{
return isSet(fp8KvCache());
}
constexpr bool hasFp8Qdq() const noexcept
{
return isSet(fp8Qdq());
}
constexpr bool hasFp8RowWise() const noexcept
{
return isSet(fp8RowWise());
}
constexpr bool hasKvCacheQuant() const noexcept
{
return hasInt8KvCache() || hasFp8KvCache();
}
static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
bool useW4a8QServe = false)
{
QuantMode quantMode{};
if (quantizeWeights)
{
if (useInt4Weights)
quantMode += int4Weights();
else
quantMode += int8Weights();
}
if (quantizeActivations)
{
quantMode += activations();
}
if (perChannel)
{
quantMode += QuantMode::perChannelScaling();
}
if (perToken)
{
quantMode += QuantMode::perTokenScaling();
}
if (perGroup)
{
quantMode += QuantMode::perGroupScaling();
}
if (useInt8KvCache)
{
quantMode += int8KvCache();
}
if (useFp8KvCache)
{
quantMode += fp8KvCache();
}
if (useFp8Qdq)
{
quantMode += fp8Qdq();
}
if (useFp8RowWise)
{
quantMode += fp8RowWise();
}
if (useW4a8QServe)
{
quantMode += w4a8QServe();
}
return quantMode;
}
static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
{
return fromDescription(true, true, perToken, perChannel);
}
static constexpr QuantMode useQServe(bool perGroup)
{
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
}
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
{
return fromDescription(true, false, false, false, perGroup, useInt4Weights);
}
static QuantMode const fromQuantAlgo(
std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
{
QuantMode quantMode{};
if (quantAlgo == "W8A16")
{
quantMode = useWeightOnly(false, false);
}
else if (quantAlgo == "W4A16")
{
quantMode = useWeightOnly(true, false);
}
else if (quantAlgo == "W4A16_AWQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W4A8_AWQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
{
quantMode = useQServe(false);
}
else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
{
quantMode = useQServe(true);
}
else if (quantAlgo == "W4A16_GPTQ")
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
{
quantMode = useSmoothQuant(false, true);
}
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
{
quantMode = useSmoothQuant(false, false);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
{
quantMode = useSmoothQuant(true, true);
}
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
{
quantMode = useSmoothQuant(false, true);
}
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
{
quantMode = useSmoothQuant(true, false);
}
else if (quantAlgo == "FP8")
{
quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
}
else if (quantAlgo == "FP8_ROWWISE")
{
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
}
if (kvCacheQuantAlgo == "INT8")
{
quantMode += int8KvCache();
}
else if (kvCacheQuantAlgo == "FP8")
{
quantMode += fp8KvCache();
}
return quantMode;
}
constexpr QuantMode operator+(QuantMode const& other) const noexcept
{
return QuantMode(mValue | other.mValue);
}
constexpr QuantMode& operator+=(QuantMode const& other) noexcept
{
return *this = *this + other;
}
constexpr QuantMode operator-(QuantMode const& other) const noexcept
{
return QuantMode(mValue & ~other.mValue);
}
constexpr QuantMode& operator-=(QuantMode const& other) noexcept
{
return *this = *this - other;
}
constexpr bool operator==(QuantMode const& other) const noexcept
{
return mValue == other.mValue;
}
constexpr bool operator!=(QuantMode const& other) const noexcept
{
return !(*this == other);
}
private:
BaseType mValue{0};
};
} // namespace common
} // namespace tensorrt_llm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment