Commit 81a882ad authored by jixx's avatar jixx
Browse files

add tgi2.4.0

parent 9822d7f6
cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
include(FetchContent)
include(ExternalProject)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies ####
include(cmake/fmt.cmake)
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
message(STATUS "Building tests")
FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2
GIT_TAG v3.6.0
)
FetchContent_MakeAvailable(Catch2)
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
include(CTest)
include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif ()
[package]
name = "text-generation-backends-trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "0.1"
async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
thiserror = "1.0.63"
tracing = "0.1"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
[build-dependencies]
cmake = "0.1"
cxx-build = { version = "1.0", features = ["parallel"] }
pkg-config = "0.3"
# Text Generation Inference - TensorRT-LLM Backend Implementation
## Description
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
## Simplified Request Sequence
```mermaid
sequenceDiagram
actor User
participant TextGenerationInference.HttpServer
participant TextGenerationInference.TensorRtLlmBackend
participant TextGenerationInference.TensorRtLlmWorkerThread
participant TensorRtLlm.Executor
participant Nvidia.Gpu
User ->> TextGenerationInference.HttpServer: POST /generate
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
activate Nvidia.Gpu
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
rect rgb(10, 92, 54)
loop every 100us
rect rgb(15, 81, 50)
alt Acquire lock to query executor
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
else There are new generated tokens
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
rect rgb(11, 110, 79)
alt Generated token is final
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
else Generated token is not final
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
end
end
end
end
deactivate Nvidia.Gpu
end
end
```
use cxx_build::CFG;
use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.6";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
// Dependencies
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"),
("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"),
("dylib", "decoder_attention"),
];
macro_rules! probe {
($name: expr, $version: expr) => {
if let Err(_) = pkg_config::probe_library($name) {
pkg_config::probe_library(&format!("{}-{}", $name, $version))
.expect(&format!("Failed to locate {}", $name));
}
};
}
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() {
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
}
let _ = cmake::Config::new(".")
.uses_cxx11()
.generator("Ninja")
.profile(match is_debug {
true => "Debug",
false => "Release",
})
.env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
.build();
// Additional transitive CMake dependencies
let deps_folder = out_dir.join("build").join("_deps");
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
let dep_name = match is_debug {
true => format!("{}d", dependency),
false => String::from(dependency),
};
let dep_path = deps_folder.join(format!("{}-build", dependency));
println!("cargo:rustc-link-search={}", dep_path.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
// Emit linkage information from the artifacts we just built
let install_lib_path = install_path.join("lib");
println!(
r"cargo:warning=Adding link search path: {}",
install_lib_path.display()
);
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
.include(deps_folder.join("fmt-src").join("include"))
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
.file("src/ffi.cpp")
.std("c++20")
.define("NDEBUG", ndebug)
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h");
println!("cargo:rerun-if-changed=src/ffi.cpp");
}
fn main() {
// Misc variables
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"),
_ => (false, "3"),
};
// Build the backend
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder, is_debug);
// Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION);
// Probe CUDA & co. with pkg-config
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
probe!(name, CUDA_REQUIRED_VERSION);
});
// NCCL is slightly trickier because it might not have a pkgconfig installed
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
println!("cargo:rustc-link-lib=dylib=nccl");
// TensorRT
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
println!("cargo:rustc-link-lib=dylib=nvinfer");
// TensorRT-LLM
TENSORRT_LLM_TRANSITIVE_DEPS
.iter()
.for_each(|(link_type, name)| {
println!("cargo:rustc-link-lib={}={}", link_type, name);
});
// Backend
BACKEND_DEPS.iter().for_each(|name| {
println!("cargo:rustc-link-lib=static={}", name);
});
}
FetchContent_Declare(
fmt
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
)
FetchContent_MakeAvailable(fmt)
fetchcontent_declare(
json
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
)
fetchcontent_makeavailable(json)
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
set(SPDLOG_FMT_EXTERNAL ON)
# Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
else ()
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
endif ()
fetchcontent_declare(
spdlog
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
)
fetchcontent_makeavailable(spdlog)
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
set(USE_CXX11_ABI ON)
set(BUILD_PYT OFF)
set(BUILD_PYBIND OFF)
set(BUILD_MICRO_BENCHMARKS OFF)
set(BUILD_BENCHMARKS OFF)
set(BUILD_TESTS OFF)
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON)
set(NVTX_DISABLE OFF)
else ()
set(FAST_BUILD OFF)
set(FAST_MATH ON)
set(NVTX_DISABLE ON)
endif ()
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
GIT_SHALLOW FALSE
DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
CACHE INTERNAL "nvrtc wrapper library path")
# The same Executor Static library
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <array>
#include <cmath>
#include <filesystem>
#include <span>
#include <vector>
#include <nlohmann/json.hpp>
#include <tensorrt_llm/runtime/common.h>
#include <tensorrt_llm/executor/executor.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
namespace huggingface::tgi::backends {
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
"Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
"Submitting inference [{}] to the executor ({:d} already in-flight)");
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
*/
void InitializeBackend();
/**
* Initialize logging mechanism
*/
void InitializeLogging();
/**
*
* @param config TensorRT-LLM configuration object
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
* @return
*/
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
/**
*
* @param worldSize
* @param workerPath
* @return
*/
tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
/**
* Get the sampling configuration from the parameters provided by TGI
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
tle::SamplingConfig GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
) noexcept;
/**
* Attempt to retrieve the
* @param generationConfigPath
* @return
*/
std::optional<std::list<std::vector<TokenId>>>
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
/**
*
*/
class TensorRtLlmBackend {
private:
const json config;
tle::Executor executor;
/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;
std::list<std::vector<TokenId>> stopWords;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
/**
* Query the executor for the number of token available for pulling
* @return
*/
[[nodiscard]] size_t NumResponsesReady() const;
/**
* Submit a new generation task to the executor
* @param tokens
* @param topK
* @param topP
* @param temperature
* @param repetitionPenalty
* @param frequencyPenalty
* @param seed
* @return Request id related to this generation for reference
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
uint32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetitionPenalty,
float_t frequencyPenalty,
uint64_t seed
);
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
};
}
#endif //TGI_TRTLLM_BACKEND_H
//
// Created by mfuntowicz on 7/11/24.
//
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
#include <cmath>
#include <cstddef>
#include <memory>
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
// Template to support returning error from TllmException back to Rust in a Result<>
#include <tensorrt_llm/common/tllmException.h>
namespace rust::behavior {
template<typename Try, typename Fail>
static void trycatch(Try &&func, Fail &&fail) noexcept try {
func();
} catch (tensorrt_llm::common::TllmException &e) {
fail(e.what());
}
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
*
* @param engineFolder
* @param executorWorker
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t
Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/***
*
* @return
*/
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
};
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
}
#endif //TGI_TRTLLM_BACKEND_FFI_H
//
// Created by mfuntowicz on 7/23/24.
//
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
#define TGI_TRTLLM_BACKEND_HARDWARE_H
#include <cstdint>
#include <limits>
#include <fmt/base.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 9
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device
*/
struct CudaComputeCapabilities {
int32_t major;
int32_t minor;
[[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
[[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
};
CudaComputeCapabilities GetCudaComputeCapabilities() {
// Get the compute capabilities of the current hardware
nvmlDevice_t device;
CudaComputeCapabilities capabilities{0, 0};
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
}
}
return capabilities;
}
/**
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
* @return
*/
std::optional<size_t> GetNumDevices() {
uint32_t numGpus = 0;
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
return std::optional(numGpus);
} else {
return std::nullopt;
}
}
}
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
#include <cstdlib>
#include <fstream>
#include <fmt/ranges.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
#include "backend.h"
#include "hardware.h"
void huggingface::tgi::backends::InitializeLogging() {
#ifdef NDEBUG
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
return std::tolower(c);
});
if (log_level == "debug")
spdlog::set_level(spdlog::level::debug);
else
spdlog::set_level(spdlog::level::info);
}
#else
spdlog::set_level(spdlog::level::debug);
#endif
}
void huggingface::tgi::backends::InitializeBackend() {
SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2();
initTrtLlmPlugins();
InitializeLogging();
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
} else {
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
}
}
[[nodiscard]]
tle::ParallelConfig
huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
auto mode = tle::CommunicationMode::kLEADER;
std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
if (worldSize > 1) {
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
mode = tle::CommunicationMode::kORCHESTRATOR;
orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
} else {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
}
return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
}
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
// Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
// Define some configuration variables
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
return execConfig;
}
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
const uint32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed) noexcept {
return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
topP,
std::nullopt,
std::nullopt,
std::nullopt,
seed,
temperature,
temperature,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty
);
}
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
huggingface::tgi::backends::GetStopWordsFromConfig(
const std::filesystem::path &generationConfigPath) noexcept {
if (exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
return stopWords;
} else {
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
}
return std::nullopt;
}
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string())) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
// Ensure we have enough GPUs on the system
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
if (numGpus < worldSize) {
SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
// todo : raise exception to catch on rust side
}
// Cache variables
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
// Attempt to discover stopWords from the generation_config.json
const auto generationConfigPath = enginesFolder / "generation_config.json";
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
#ifdef NDEBUG
return executor.getNumResponsesReady();
#else
const auto numResponses = executor.getNumResponsesReady();
if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
return numResponses;
#endif
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetitionPenalty,
const float_t frequencyPenalty,
const uint64_t seed
) {
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
#ifndef NDEBUG
{
const auto &iterations = executor.getLatestIterationStats();
const auto &lastIteration = iterations.front();
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
}
#endif
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
// Build the request
auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
request.setStopWords(stopWords);
// Submit to the executor for batching
return executor.enqueueRequest(request);
}
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
return executor.awaitResponses();
}
#!/bin/bash
set -ex
TRT_VER_BASE="10.4.0"
TRT_VER_FULL="${TRT_VER_BASE}.26"
CUDA_VER="12.6"
CUDNN_VER="9.5.0.50-1"
NCCL_VER="2.22.3-1+cuda12.6"
CUBLAS_VER="12.6.3.3-1"
NVRTC_VER="12.6.77-1"
for i in "$@"; do
case $i in
--TRT_VER=?*) TRT_VER="${i#*=}";;
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
*) ;;
esac
shift
done
NVCC_VERSION_OUTPUT=$(nvcc --version)
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
exit 1
fi
install_ubuntu_requirements() {
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
fi
if [[ $(apt list --installed | grep libnccl) ]]; then
apt-get remove --purge -y --allow-change-held-packages libnccl*
fi
if [[ $(apt list --installed | grep libcublas) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcublas*
fi
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
fi
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
# NVRTC static library doesn't exist in NGC PyTorch container.
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
apt-get clean
rm -rf /var/lib/apt/lists/*
}
install_centos_requirements() {
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
yum -y update
yum -y install epel-release
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
yum clean all
}
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.6"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
}
# Install base packages depending on the base OS
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
debian)
install_ubuntu_requirements
install_tensorrt
;;
ubuntu)
install_ubuntu_requirements
install_tensorrt
;;
centos)
install_centos_requirements
install_tensorrt
;;
*)
echo "Unable to determine OS..."
exit 1
;;
esac
use std::path::PathBuf;
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Provided engine folder {0} doesn't exist")]
EngineFolderDoesntExists(PathBuf),
#[error("Provided executorWorker binary path {0} doesn't exist")]
ExecutorWorkerNotFound(PathBuf),
#[error("TensorRT-LLM Runtime error: {0}")]
Runtime(String),
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
//
// Created by mfuntowicz on 6/30/24.
//
#pragma once
#include <algorithm>
#include <exception>
#include <filesystem>
#include <functional>
#include <limits>
#include <iterator>
#include <ranges>
#include <vector>
#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
const std::string_view &engineFolder,
const std::string_view &executorWorker
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens,
uint32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(
std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
const auto responses = TensorRtLlmBackend::PullNewTokens();
auto steps = std::make_unique<std::vector<GenerationStep>>();
steps->reserve(responses.size());
#ifndef NDEBUG
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
#endif
// Transform tle::Response to GenerationStep
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) {
const auto result = r.getResult();
return GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
};
} else {
return GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
return steps;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
SPDLOG_INFO("Creating TensorRT-LLM Backend");
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
}
pub use looper::TensorRtLlmBackendV2;
pub mod errors;
mod looper;
mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
#[derive(Debug, Clone)]
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
has_error: bool,
error_msg: String,
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
/// Represent an instance of the underlying TensorRT-LLM backend
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
/// * `executor_worker`: Path to the TRTLLM executor worker
///
/// returns: <unknown>
///
/// # Examples
///
/// ```
///
/// ```
#[rust_name = "create_tensorrt_llm_backend"]
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
#[rust_name = "submit"]
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
top_k: i32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> Result<u64>;
#[rust_name = "pull_tokens"]
fn PullTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
}
}
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k as i32,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
}
}
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
// Submit the work to p:the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
tokenizer: Tokenizer,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
}
Err(_err) => {
todo!("what do we do?")
}
}
}
}
}
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result<Self, TensorRtLlmBackendError> {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling requests state at regular interval
let executor_looper = spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
executor_receiver,
post_processor_sender,
)
});
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
async fn health(&self, _: bool) -> bool {
!self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment