-#### Build environment preparation
+## Table of contents
-There are two ways to set up the build environment.
-##### Option 1:
+ - [Get Started](#get-started)
+ - [Docker](#docker)
+ - [API documentation](#api-documentation)
+ - [Using a private or gated model](#using-a-private-or-gated-model)
+ - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
+ - [Distributed Tracing](#distributed-tracing)
+ - [Architecture](#architecture)
+ - [Local install](#local-install)
+ - [Optimized architectures](#optimized-architectures)
+ - [Run locally](#run-locally)
+ - [Run](#run)
+ - [Quantization](#quantization)
+ - [Develop](#develop)
+ - [Testing](#testing)
-### **TODO**
+Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:
-##### Option 2:
+- Simple launcher to serve most popular LLMs
+- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
+- Tensor Parallelism for faster inference on multiple GPUs
+- Token streaming using Server-Sent Events (SSE)
+- Continuous batching of incoming requests for increased total throughput
+- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with the OpenAI Chat Completion API
+- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
+- Quantization with:
+ - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+ - [GPT-Q](https://arxiv.org/abs/2210.17323)
+ - [EETQ](https://github.com/NetEase-FuXi/EETQ)
+ - [AWQ](https://github.com/casper-hansen/AutoAWQ)
+ - [Marlin](https://github.com/IST-DASLab/marlin)
+ - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
+- [Safetensors](https://github.com/huggingface/safetensors) weight loading
+- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+- Logits warper (temperature scaling, top-p, top-k, repetition penalty; see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor) for more details)
+- Stop sequences
+- Log probabilities
+- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) for ~2x lower latency
+- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify the output format to speed up inference and make sure the output is valid according to some specs (see the example after this list).
+- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
+- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
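+
+As a quick illustration of the Guidance/JSON feature mentioned above, here is a minimal sketch of a constrained request against `/generate`; the prompt and JSON schema are made up for this example (see the Guidance docs linked above for the full API):
+
+```bash
+# illustrative only: constrain the output to a small JSON schema via parameters.grammar
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -H 'Content-Type: application/json' \
+    -d '{
+    "inputs": "Name one deep learning framework.",
+    "parameters": {
+        "max_new_tokens": 64,
+        "grammar": {
+            "type": "json",
+            "value": {
+                "type": "object",
+                "properties": {"framework": {"type": "string"}},
+                "required": ["framework"]
+            }
+        }
+    }
+}'
+```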
-Based on the 光源 (SourceFind) pytorch2.1.0 base image. Image download: [https://sourcefind.cn/#/image/dcu/pytorch](https://sourcefind.cn/#/image/dcu/pytorch); download the image version matching pytorch2.1.0, python, dtk and your OS. The pytorch2.1.0 image already has triton and flash-attn installed.
+### Hardware support
+
+- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
+- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
+- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
+- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
+- [Gaudi](https://github.com/huggingface/tgi-gaudi)
+- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving)
+
+
+## Get Started
+
+### Docker
+
+For a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). The easiest way of getting started is using the official Docker container:
+
+```shell
+model=HuggingFaceH4/zephyr-7b-beta
+# share a volume with the Docker container to avoid downloading weights every run
+volume=$PWD/data
+
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:2.4.0 --model-id $model
+```
+
+And then you can make requests like
+
+```bash
+curl 127.0.0.1:8080/generate_stream \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain OpenAI Chat Completion API-compatible responses.
+
+```bash
+curl localhost:8080/v1/chat/completions \
+ -X POST \
+ -d '{
+ "model": "tgi",
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant."
+ },
+ {
+ "role": "user",
+ "content": "What is deep learning?"
+ }
+ ],
+ "stream": true,
+ "max_tokens": 20
+}' \
+ -H 'Content-Type: application/json'
+```
+
+**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. To run the Docker container on a machine with no GPUs or CUDA support, remove the `--gpus all` flag and add `--disable-custom-kernels`. Please note that CPU is not the intended platform for this project, so performance might be subpar.
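+
+For example, a CPU-only run adapted from the command above might look roughly like this (illustrative only, and expect it to be slow):
+
+```shell
+# same $model/$volume variables as above; no --gpus flag, custom kernels disabled
+docker run --shm-size 1g -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.4.0 \
+    --model-id $model --disable-custom-kernels
+```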
+
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.4.0-rocm --model-id $model` instead of the command above.
+
+To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
+```
+text-generation-launcher --help
+```
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Using a private or gated model
+
+You can use the `HF_TOKEN` environment variable to configure the token employed by
+`text-generation-inference`. This allows you to gain access to protected resources.
+
+For example, if you want to serve the gated Llama V2 model variants:
+
+1. Go to https://huggingface.co/settings/tokens
+2. Copy your cli READ token
+3. Export `HF_TOKEN=`
+
+or with Docker:
+
+```shell
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+token=
+
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.4.0 --model-id $model
+```
+
+### A note on Shared Memory (shm)
+
+[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
+`PyTorch` to do distributed training/inference. `text-generation-inference` makes
+use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
+
+In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
+peer-to-peer using NVLink or PCI is not possible.
+
+To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.
+
+If you are running `text-generation-inference` inside `Kubernetes`, you can also add Shared Memory to the container by
+creating a volume with:
+
+```yaml
+- name: shm
+ emptyDir:
+ medium: Memory
+ sizeLimit: 1Gi
+```
+
+and mounting it to `/dev/shm`.
+
+Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
+this will impact performance.
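+
+For example, with the Docker command above, disabling SHM could look like the following sketch (illustrative only):
+
+```shell
+# NCCL_SHM_DISABLE=1 turns off SHM-based communication; expect slower multi-GPU inference
+docker run --gpus all -e NCCL_SHM_DISABLE=1 -p 8080:80 -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.4.0 --model-id $model
+```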
+
+### Distributed Tracing
+
+`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
+by setting the address of an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
+overridden with the `--otlp-service-name` argument.
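+
+For example, a launch with tracing enabled might look like this sketch (the collector address and service name below are placeholders, not defaults):
+
+```shell
+# send traces to a local OTLP collector under a custom service name
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 \
+    --otlp-endpoint http://localhost:4317 \
+    --otlp-service-name text-generation-inference.demo
+```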
+
+### Architecture
+
+
+
+Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally.
+
+First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
+Python 3.9, e.g. using `conda`:
-1. Install Rust
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+
+conda create -n text-generation-inference python=3.11
+conda activate text-generation-inference
```
-2. Install Protoc
+You may also need to install Protoc.
+
+On Linux:
+
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
@@ -49,50 +216,77 @@ sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
-3. Install the TGI service
-```bash
-git clone http://developer.hpccube.com/codes/wangkx1/text_generation_server-dcu.git # switch to the branch you need
-cd text-generation-inference
-pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
-pip install -r pre_requirements.txt
+On MacOS, using Homebrew:
-# install exllama
-cd server
-make install-exllama # install exllama kernels
-make install-exllamav2 # install exllamav2 kernels
+```shell
+brew install protobuf
+```
-cd .. # return to the project root
-source $HOME/.cargo/env
-BUILD_EXTENSIONS=True make install # install the text-generation service
+Then run:
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
-4. Install the benchmark
-```bash
-cd text-generation-inference
-make install-benchmark
-```
-Note: if the installation is too slow, you can speed it up by switching the default package index with the following command.
-```bash
+**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
+
+```shell
+sudo apt-get install libssl-dev gcc -y
```
-Also, if `cargo install` is too slow, you can speed it up by adding a mirror source in `~/.cargo/config`.
-## Check the installed version
-```bash
-text-generation-launcher -V # the version number matches the official release
+## Optimized architectures
+
+TGI works out of the box to serve optimized implementations of all modern model architectures. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
+
+Other architectures are supported on a best-effort basis using:
+
+`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
+
+or
+
+`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
+
+
+
+## Run locally
+
+### Run
+
+```shell
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
-## Before use
+### Quantization
-```bash
-export PYTORCH_TUNABLEOP_ENABLED=0
+You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or quantize weights on the fly with bitsandbytes, EETQ, or fp8 to reduce the VRAM requirement:
+
+```shell
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
```
-## Known Issues
+4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
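+
+As a concrete illustration of the flag described above, enabling on-the-fly NF4 quantization (reusing the model id from the earlier example) could look like this:
+
+```shell
+# load the weights in 4-bit NF4 to reduce VRAM usage
+text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize bitsandbytes-nf4
+```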
+
+Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
+
+## Develop
-- None
+```shell
+make server-dev
+make router-dev
+```
+
+## Testing
-## References
-- [README_ORIGIN](README_ORIGIN.md)
-- [https://github.com/huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference)
+```shell
+# python
+make python-server-tests
+make python-client-tests
+# or both server and client tests
+make python-tests
+# rust cargo tests
+make rust-tests
+# integration tests
+make integration-tests
+```
diff --git a/README_ORINGIN.md b/README_ORINGIN.md
deleted file mode 100644
index 74616748efa88a47785eaff7c1952370aaad20ea..0000000000000000000000000000000000000000
--- a/README_ORINGIN.md
+++ /dev/null
@@ -1,258 +0,0 @@
-
-
-
-
-
-
-# Text Generation Inference
-
-
-
-
-
-
-
-
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Hugging Chat, the Inference API and Inference Endpoint.
-
-
-
-## Table of contents
-
-- [Get Started](#get-started)
- - [API Documentation](#api-documentation)
- - [Using a private or gated model](#using-a-private-or-gated-model)
- - [A note on Shared Memory](#a-note-on-shared-memory-shm)
- - [Distributed Tracing](#distributed-tracing)
- - [Local Install](#local-install)
- - [CUDA Kernels](#cuda-kernels)
-- [Optimized architectures](#optimized-architectures)
-- [Run Mistral](#run-a-model)
- - [Run](#run)
- - [Quantization](#quantization)
-- [Develop](#develop)
-- [Testing](#testing)
-
-Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:
-
-- Simple launcher to serve most popular LLMs
-- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-- Tensor Parallelism for faster inference on multiple GPUs
-- Token streaming using Server-Sent Events (SSE)
-- Continuous batching of incoming requests for increased total throughput
-- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with :
- - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- - [GPT-Q](https://arxiv.org/abs/2210.17323)
- - [EETQ](https://github.com/NetEase-FuXi/EETQ)
- - [AWQ](https://github.com/casper-hansen/AutoAWQ)
-- [Safetensors](https://github.com/huggingface/safetensors) weight loading
-- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
-- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
-- Stop sequences
-- Log probabilities
-- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency
-- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify output format to speed up inference and make sure the output is valid according to some specs..
-- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
-- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
-
-### Hardware support
-
-- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
-- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
-- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
-- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
-- [Gaudi](https://github.com/huggingface/tgi-gaudi)
-- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving)
-
-
-## Get Started
-
-### Docker
-
-For a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). The easiest way of getting started is using the official Docker container:
-
-```shell
-model=HuggingFaceH4/zephyr-7b-beta
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
-```
-
-And then you can make requests like
-
-```bash
-curl 127.0.0.1:8080/generate_stream \
- -X POST \
- -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
- -H 'Content-Type: application/json'
-```
-
-**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
-
-To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
-```
-text-generation-launcher --help
-```
-
-### API documentation
-
-You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
-The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
-
-### Using a private or gated model
-
-You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by
-`text-generation-inference`. This allows you to gain access to protected resources.
-
-For example, if you want to serve the gated Llama V2 model variants:
-
-1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HUGGING_FACE_HUB_TOKEN=`
-
-or with Docker:
-
-```shell
-model=meta-llama/Llama-2-7b-chat-hf
-volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
-token=
-
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
-```
-
-### A note on Shared Memory (shm)
-
-[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
-use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
-
-In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
-peer-to-peer using NVLink or PCI is not possible.
-
-To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.
-
-If you are running `text-generation-inference` inside `Kubernetes`. You can also add Shared Memory to the container by
-creating a volume with:
-
-```yaml
-- name: shm
- emptyDir:
- medium: Memory
- sizeLimit: 1Gi
-```
-
-and mounting it to `/dev/shm`.
-
-Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
-this will impact performance.
-
-### Distributed Tracing
-
-`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
-by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
-
-### Architecture
-
-
-
-### Local install
-
-You can also opt to install `text-generation-inference` locally.
-
-First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
-Python 3.9, e.g. using `conda`:
-
-```shell
-curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-
-conda create -n text-generation-inference python=3.11
-conda activate text-generation-inference
-```
-
-You may also need to install Protoc.
-
-On Linux:
-
-```shell
-PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
-curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
-sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
-sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
-rm -f $PROTOC_ZIP
-```
-
-On MacOS, using Homebrew:
-
-```shell
-brew install protobuf
-```
-
-Then run:
-
-```shell
-BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
-text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
-```
-
-**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
-
-```shell
-sudo apt-get install libssl-dev gcc -y
-```
-
-## Optimized architectures
-
-TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
-
-Other architectures are supported on a best-effort basis using:
-
-`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
-
-or
-
-`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
-
-
-
-## Run locally
-
-### Run
-
-```shell
-text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
-```
-
-### Quantization
-
-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
-
-```shell
-text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
-```
-
-4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
-
-## Develop
-
-```shell
-make server-dev
-make router-dev
-```
-
-## Testing
-
-```shell
-# python
-make python-server-tests
-make python-client-tests
-# or both server and client tests
-make python-tests
-# rust cargo tests
-make rust-tests
-# integration tests
-make integration-tests
-```
diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml
similarity index 100%
rename from router/client/Cargo.toml
rename to backends/client/Cargo.toml
diff --git a/router/client/build.rs b/backends/client/build.rs
similarity index 100%
rename from router/client/build.rs
rename to backends/client/build.rs
diff --git a/router/client/src/lib.rs b/backends/client/src/lib.rs
similarity index 100%
rename from router/client/src/lib.rs
rename to backends/client/src/lib.rs
diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs
similarity index 100%
rename from router/client/src/v2/client.rs
rename to backends/client/src/v2/client.rs
diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs
similarity index 100%
rename from router/client/src/v2/mod.rs
rename to backends/client/src/v2/mod.rs
diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs
similarity index 100%
rename from router/client/src/v2/sharded_client.rs
rename to backends/client/src/v2/sharded_client.rs
diff --git a/router/client/src/v3/client.rs b/backends/client/src/v3/client.rs
similarity index 96%
rename from router/client/src/v3/client.rs
rename to backends/client/src/v3/client.rs
index a996b14fae873c552458a9cc85cb78a0bfb6d541..d43f789e7ca95421266a068e267d65e000f309ea 100644
--- a/router/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -153,9 +153,13 @@ impl Client {
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
+ // Most requests will have that
+ add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
+ cache_len: 0,
+ chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@@ -214,8 +218,13 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
+ cached_batch: Option,
) -> Result<(Vec, Option, PrefillTimings)> {
- let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+ let request = tonic::Request::new(PrefillRequest {
+ batch: Some(batch),
+ cached_batch,
+ })
+ .inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
diff --git a/router/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs
similarity index 100%
rename from router/client/src/v3/mod.rs
rename to backends/client/src/v3/mod.rs
diff --git a/router/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
similarity index 96%
rename from router/client/src/v3/sharded_client.rs
rename to backends/client/src/v3/sharded_client.rs
index ae8a899b38a6b5571e4b61cba157deb73b5ab8c3..854a5895ebab7474b2391dead5078b19e3a05370 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
+ cached_batch: Option,
) -> Result<(Vec, Option, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
- .map(|client| Box::pin(client.prefill(batch.clone())))
+ .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result, Option, PrefillTimings)>> =
@@ -221,6 +222,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
+ add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
@@ -244,6 +246,8 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
+ cache_len: 0,
+ chunk_len: None,
adapter_id: None,
};
let batch = Batch {
@@ -253,7 +257,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
- self.clone().prefill(batch).await?;
+ self.clone().prefill(batch, None).await?;
Ok(())
}
}
diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml
similarity index 100%
rename from router/grpc-metadata/Cargo.toml
rename to backends/grpc-metadata/Cargo.toml
diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs
similarity index 100%
rename from router/grpc-metadata/src/lib.rs
rename to backends/grpc-metadata/src/lib.rs
diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..831372cdf99650041bcc51d2e438ccc2c7893a4f
--- /dev/null
+++ b/backends/trtllm/CMakeLists.txt
@@ -0,0 +1,75 @@
+cmake_minimum_required(VERSION 3.20)
+
+if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+ find_program(CCACHE_EXECUTABLE "ccache")
+ if (CCACHE_EXECUTABLE)
+ message(STATUS "Using ccache")
+ set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
+ endif ()
+endif ()
+
+if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
+ cmake_policy(SET CMP0135 NEW)
+endif ()
+
+project(tgi-trtllm-backend VERSION 1.0.0)
+set(CMAKE_CXX_STANDARD 20)
+
+include(FetchContent)
+include(ExternalProject)
+
+option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
+set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
+set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
+set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
+
+# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
+find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+
+#### External dependencies ####
+include(cmake/fmt.cmake)
+include(cmake/json.cmake)
+include(cmake/spdlog.cmake)
+include(cmake/trtllm.cmake)
+
+# Let's build TRTLLM as part of CMake
+add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
+
+# Tell CMake not to try to override the RPATH for executorWorker as it has no information on how to do so
+set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
+
+# TGI TRTLLM Backend definition
+add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
+include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
+target_include_directories(tgi_trtllm_backend_impl PRIVATE
+ $
+ $
+)
+target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
+
+# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make it easy to link / find them back
+install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+
+#### Unit Tests ####
+if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
+ message(STATUS "Building tests")
+ FetchContent_Declare(
+ Catch2
+ GIT_REPOSITORY https://github.com/catchorg/Catch2
+ GIT_TAG v3.6.0
+ )
+ FetchContent_MakeAvailable(Catch2)
+
+ # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
+ # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
+
+ list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
+ include(CTest)
+ include(Catch)
+ # catch_discover_tests(tgi_trtllm_backend_tests)
+endif ()
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..97ef1a768917160d744ad92a46e88a7a54d492a4
--- /dev/null
+++ b/backends/trtllm/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "text-generation-backends-trtllm"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[dependencies]
+async-trait = "0.1"
+async-stream = "0.3"
+clap = { version = "4.5", features = ["derive"] }
+cxx = "1.0"
+hashbrown = "0.14"
+hf-hub = { workspace = true }
+log = { version = "0.4", features = [] }
+text-generation-router = { path = "../../router" }
+tokenizers = { workspace = true }
+tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.15"
+thiserror = "1.0.63"
+tracing = "0.1"
+tracing-opentelemetry = "0.25"
+tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+
+[build-dependencies]
+cmake = "0.1"
+cxx-build = { version = "1.0", features = ["parallel"] }
+pkg-config = "0.3"
diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..94064504d74869c38b70165dc8a3bbf85eafa26d
--- /dev/null
+++ b/backends/trtllm/README.md
@@ -0,0 +1,46 @@
+# Text Generation Inference - TensorRT-LLM Backend Implementation
+
+## Description
+
+This folder provides the sources of the TensorRT-LLM backend implementation, powered by the new TensorRT-LLM Executor API.
+
+## Simplified Request Sequence
+
+```mermaid
+sequenceDiagram
+ actor User
+ participant TextGenerationInference.HttpServer
+ participant TextGenerationInference.TensorRtLlmBackend
+ participant TextGenerationInference.TensorRtLlmWorkerThread
+ participant TensorRtLlm.Executor
+ participant Nvidia.Gpu
+ User ->> TextGenerationInference.HttpServer: POST /generate
+ TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
+ TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
+ TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
+ activate Nvidia.Gpu
+ TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
+ TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
+ rect rgb(10, 92, 54)
+ loop every 100us
+ rect rgb(15, 81, 50)
+ alt Acquire lock to query executor
+ TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
+ else There are new generated tokens
+ TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
+ TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
+ rect rgb(11, 110, 79)
+ alt Generated token is final
+ TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
+ TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
+ else Generated token is not final
+ TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
+ end
+ end
+ end
+ end
+ deactivate Nvidia.Gpu
+ end
+ end
+
+```
diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs
new file mode 100644
index 0000000000000000000000000000000000000000..985019260b837035907b7e8af5eb20688e85301e
--- /dev/null
+++ b/backends/trtllm/build.rs
@@ -0,0 +1,160 @@
+use cxx_build::CFG;
+use pkg_config;
+use std::env;
+use std::env::consts::ARCH;
+use std::path::{absolute, PathBuf};
+
+const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
+const CUDA_REQUIRED_VERSION: &str = "12.6";
+const MPI_REQUIRED_VERSION: &str = "4.1";
+const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
+const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
+const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
+
+// Dependencies
+const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
+const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
+const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
+ ("dylib", "tensorrt_llm"),
+ ("static", "tensorrt_llm_executor_static"),
+ ("dylib", "tensorrt_llm_nvrtc_wrapper"),
+ ("dylib", "nvinfer_plugin_tensorrt_llm"),
+ ("dylib", "decoder_attention"),
+];
+
+macro_rules! probe {
+ ($name: expr, $version: expr) => {
+ if let Err(_) = pkg_config::probe_library($name) {
+ pkg_config::probe_library(&format!("{}-{}", $name, $version))
+ .expect(&format!("Failed to locate {}", $name));
+ }
+ };
+}
+
+fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
+ // Build the backend implementation through CMake
+ let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
+ let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
+ let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
+
+ let mut install_path = PathBuf::from(install_path);
+ if !install_path.is_absolute() {
+ install_path = absolute(out_dir).expect("cannot happen").join(install_path);
+ }
+
+ let _ = cmake::Config::new(".")
+ .uses_cxx11()
+ .generator("Ninja")
+ .profile(match is_debug {
+ true => "Debug",
+ false => "Release",
+ })
+ .env("OPT_LEVEL", opt_level)
+ .define("CMAKE_INSTALL_PREFIX", &install_path)
+ .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
+ .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
+ .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
+ .build();
+
+ // Additional transitive CMake dependencies
+ let deps_folder = out_dir.join("build").join("_deps");
+ for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
+ let dep_name = match is_debug {
+ true => format!("{}d", dependency),
+ false => String::from(dependency),
+ };
+ let dep_path = deps_folder.join(format!("{}-build", dependency));
+ println!("cargo:rustc-link-search={}", dep_path.display());
+ println!("cargo:rustc-link-lib=static={}", dep_name);
+ }
+
+ // Emit linkage information from the artifacts we just built
+ let install_lib_path = install_path.join("lib");
+
+ println!(
+ r"cargo:warning=Adding link search path: {}",
+ install_lib_path.display()
+ );
+ println!(r"cargo:rustc-link-search={}", install_lib_path.display());
+
+ (PathBuf::from(install_path), deps_folder)
+}
+
+fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
+ let ndebug = match is_debug {
+ true => "1",
+ false => "0",
+ };
+
+ CFG.include_prefix = "backends/trtllm";
+ cxx_build::bridge("src/lib.rs")
+ .static_flag(true)
+ .include(deps_folder.join("fmt-src").join("include"))
+ .include(deps_folder.join("spdlog-src").join("include"))
+ .include(deps_folder.join("json-src").join("include"))
+ .include(deps_folder.join("trtllm-src").join("cpp").join("include"))
+ .include("/usr/local/cuda/include")
+ .include("/usr/local/tensorrt/include")
+ .file("src/ffi.cpp")
+ .std("c++20")
+ .define("NDEBUG", ndebug)
+ .compile("tgi_trtllm_backend");
+
+ println!("cargo:rerun-if-changed=CMakeLists.txt");
+ println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
+ println!("cargo:rerun-if-changed=cmake/json.cmake");
+ println!("cargo:rerun-if-changed=cmake/fmt.cmake");
+ println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
+ println!("cargo:rerun-if-changed=include/backend.h");
+ println!("cargo:rerun-if-changed=lib/backend.cpp");
+ println!("cargo:rerun-if-changed=include/ffi.h");
+ println!("cargo:rerun-if-changed=src/ffi.cpp");
+}
+
+fn main() {
+ // Misc variables
+ let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
+ let build_profile = env::var("PROFILE").unwrap();
+ let (is_debug, opt_level) = match build_profile.as_ref() {
+ "debug" => (true, "0"),
+ _ => (false, "3"),
+ };
+
+ // Build the backend
+ let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
+
+ // Build the FFI layer calling the backend above
+ build_ffi_layer(&deps_folder, is_debug);
+
+ // Emit linkage search path
+ probe!("ompi", MPI_REQUIRED_VERSION);
+
+ // Probe CUDA & co. with pkg-config
+ CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
+ probe!(name, CUDA_REQUIRED_VERSION);
+ });
+
+ // NCCL is slightly trickier because it might not have a pkgconfig installed
+ let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
+ let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
+ println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
+ println!("cargo:rustc-link-lib=dylib=nccl");
+
+ // TensorRT
+ let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
+ println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
+ println!("cargo:rustc-link-lib=dylib=nvinfer");
+
+ // TensorRT-LLM
+ TENSORRT_LLM_TRANSITIVE_DEPS
+ .iter()
+ .for_each(|(link_type, name)| {
+ println!("cargo:rustc-link-lib={}={}", link_type, name);
+ });
+
+ // Backend
+ BACKEND_DEPS.iter().for_each(|name| {
+ println!("cargo:rustc-link-lib=static={}", name);
+ });
+}
diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..afd6ea5f0906e31a491d2113d6527cd9d6a9850e
--- /dev/null
+++ b/backends/trtllm/cmake/fmt.cmake
@@ -0,0 +1,6 @@
+FetchContent_Declare(
+ fmt
+ DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
+)
+FetchContent_MakeAvailable(fmt)
diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..67eff2fe606708a4561b2f9ad3f19441c561a92d
--- /dev/null
+++ b/backends/trtllm/cmake/json.cmake
@@ -0,0 +1,6 @@
+fetchcontent_declare(
+ json
+ DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
+)
+fetchcontent_makeavailable(json)
diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..7f529a7d29e5ab39d9af9b233399f9cc3da1d9da
--- /dev/null
+++ b/backends/trtllm/cmake/spdlog.cmake
@@ -0,0 +1,17 @@
+set(SPDLOG_USE_FMT ON)
+set(SPDLOG_BUILD_SHARED OFF)
+set(SPDLOG_FMT_EXTERNAL ON)
+
+# Define the level at which SPDLOG_ compilation level is defined
+if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+else ()
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+endif ()
+
+fetchcontent_declare(
+ spdlog
+ DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+)
+fetchcontent_makeavailable(spdlog)
diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..5f1b6c19c01f25e31a5703d75f6f8774b1a7250b
--- /dev/null
+++ b/backends/trtllm/cmake/trtllm.cmake
@@ -0,0 +1,43 @@
+set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
+set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
+
+set(USE_CXX11_ABI ON)
+set(BUILD_PYT OFF)
+set(BUILD_PYBIND OFF)
+set(BUILD_MICRO_BENCHMARKS OFF)
+set(BUILD_BENCHMARKS OFF)
+set(BUILD_TESTS OFF)
+set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
+
+message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
+if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+ set(FAST_BUILD ON)
+ set(NVTX_DISABLE OFF)
+else ()
+ set(FAST_BUILD OFF)
+ set(FAST_MATH ON)
+ set(NVTX_DISABLE ON)
+endif ()
+
+fetchcontent_declare(
+ trtllm
+ GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
+ GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
+ GIT_SHALLOW FALSE
+ DOWNLOAD_EXTRACT_TIMESTAMP
+)
+fetchcontent_makeavailable(trtllm)
+
+message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
+execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
+execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
+
+# TRTLLM uses a JIT-based *precompiled* library to generate some specific kernels; we generate the path to it here
+set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
+set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
+ CACHE INTERNAL "nvrtc wrapper library path")
+
+# The same Executor Static library
+set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
+set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
diff --git a/server/marlin/marlin_kernels/py.typed b/backends/trtllm/cmake/utils/detect_cuda_arch.cu
similarity index 100%
rename from server/marlin/marlin_kernels/py.typed
rename to backends/trtllm/cmake/utils/detect_cuda_arch.cu
diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
new file mode 100644
index 0000000000000000000000000000000000000000..d23f6288964d74a6c71e11bc13b3b235a4d592fa
--- /dev/null
+++ b/backends/trtllm/include/backend.h
@@ -0,0 +1,144 @@
+//
+// Created by Morgan Funtowicz on 6/30/24.
+//
+
+#ifndef TGI_TRTLLM_BACKEND_H
+#define TGI_TRTLLM_BACKEND_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+using json = nlohmann::json;
+namespace tle = tensorrt_llm::executor;
+
+
+#define CAST_SIZETYPE(x) static_cast(x)
+
+namespace huggingface::tgi::backends {
+ using RequestId = tle::IdType;
+ using TokenId = tle::TokenIdType;
+
+ const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
+ constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
+ "Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
+ constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
+ "Submitting inference [{}] to the executor ({:d} already in-flight)");
+ constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
+ "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
+
+ /**
+ * Initialize all the components required by TRTLLM.
+ * It is required to call this function before attempting to load any engine
+ */
+ void InitializeBackend();
+
+ /**
+ * Initialize logging mechanism
+ */
+ void InitializeLogging();
+
+
+ /**
+ *
+ * @param config TensorRT-LLM configuration object
+ * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
+ * @return
+ */
+ tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
+
+ /**
+ *
+ * @param worldSize
+ * @param workerPath
+ * @return
+ */
+ tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
+
+ /**
+ * Get the sampling configuration from the parameters provided by TGI
+ * @param topK
+ * @param topP
+ * @param temperature
+ * @param repetition_penalty
+ * @param frequency_penalty
+ * @param seed
+ * @return
+ */
+ tle::SamplingConfig GetSamplingConfig(
+ uint32_t topK,
+ float_t topP,
+ float_t temperature,
+ float_t repetition_penalty,
+ float_t frequency_penalty,
+ uint64_t seed
+ ) noexcept;
+
+ /**
+ * Attempt to retrieve the stop words from the generation_config.json, if any
+ * @param generationConfigPath
+ * @return
+ */
+ std::optional>>
+ GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
+
+ /**
+ *
+ */
+ class TensorRtLlmBackend {
+ private:
+ const json config;
+ tle::Executor executor;
+
+ /** Frequently accessed variables cached here **/
+ uint32_t maxNumTokens;
+ std::list> stopWords;
+
+ public:
+ explicit TensorRtLlmBackend(
+ const std::filesystem::path &engineFolder,
+ const std::filesystem::path &executorWorker
+ );
+
+ /**
+ * Query the executor for the number of tokens available for pulling
+ * @return
+ */
+ [[nodiscard]] size_t NumResponsesReady() const;
+
+ /**
+ * Submit a new generation task to the executor
+ * @param tokens
+ * @param topK
+ * @param topP
+ * @param temperature
+ * @param repetitionPenalty
+ * @param frequencyPenalty
+ * @param seed
+ * @return Request id related to this generation for reference
+ */
+ [[nodiscard]] RequestId Submit(
+ const std::vector &tokens,
+ uint32_t maxNewTokens,
+ int32_t topK,
+ float_t topP,
+ float_t temperature,
+ float_t repetitionPenalty,
+ float_t frequencyPenalty,
+ uint64_t seed
+ );
+
+ [[nodiscard]] std::vector PullNewTokens();
+ };
+}
+
+
+#endif //TGI_TRTLLM_BACKEND_H
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
new file mode 100644
index 0000000000000000000000000000000000000000..449bcd4d7398146867628825dfe182419b5666e6
--- /dev/null
+++ b/backends/trtllm/include/ffi.h
@@ -0,0 +1,75 @@
+//
+// Created by mfuntowicz on 7/11/24.
+//
+
+#ifndef TGI_TRTLLM_BACKEND_FFI_H
+#define TGI_TRTLLM_BACKEND_FFI_H
+
+#include
+#include
+#include
+#include "backend.h"
+
+namespace huggingface::tgi::backends {
+ class TensorRtLlmBackendImpl;
+}
+
+// Template to support returning error from TllmException back to Rust in a Result<>
+#include
+
+namespace rust::behavior {
+ template
+ static void trycatch(Try &&func, Fail &&fail) noexcept try {
+ func();
+ } catch (tensorrt_llm::common::TllmException &e) {
+ fail(e.what());
+ }
+}
+
+#include "backends/trtllm/src/lib.rs.h"
+
+namespace huggingface::tgi::backends {
+
+ class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
+ public:
+ /***
+ *
+ * @param engineFolder
+ * @param executorWorker
+ */
+ TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
+
+ /***
+ *
+ * @param tokens
+ * @param maxNewTokens
+ * @param topK
+ * @param topP
+ * @param temperature
+ * @param repetition_penalty
+ * @param frequency_penalty
+ * @param seed
+ * @return
+ */
+ [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
+ uint64_t
+ Submit(rust::Slice tokens, uint32_t maxNewTokens,
+ int32_t topK, float_t topP, float_t temperature,
+ float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
+
+ /***
+ *
+ * @return
+ */
+ std::unique_ptr> PullTokens();
+ };
+
+ /***
+ *
+ * @param engineFolder
+ * @return
+ */
+ std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
+}
+
+#endif //TGI_TRTLLM_BACKEND_FFI_H
diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h
new file mode 100644
index 0000000000000000000000000000000000000000..9633495f4fd54682da6508936b9cb853e221411d
--- /dev/null
+++ b/backends/trtllm/include/hardware.h
@@ -0,0 +1,59 @@
+//
+// Created by mfuntowicz on 7/23/24.
+//
+
+#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
+#define TGI_TRTLLM_BACKEND_HARDWARE_H
+
+#include
+#include
+#include
+#include
+#include
+
+namespace huggingface::hardware::cuda {
+
+#define AMPERE_SM_MAJOR 8
+#define HOPPER_SM_MAJOR 9
+
+ /**
+ * Store information about the version of the CUDA Compute Capabilities detected on the device
+ */
+ struct CudaComputeCapabilities {
+ int32_t major;
+ int32_t minor;
+
+ [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
+
+ [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
+ };
+
+ CudaComputeCapabilities GetCudaComputeCapabilities() {
+ // Get the compute capabilities of the current hardware
+ nvmlDevice_t device;
+ CudaComputeCapabilities capabilities{0, 0};
+ if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
+ SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
+ if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
+ SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
+ }
+ }
+
+ return capabilities;
+ }
+
+ /**
+ * Return the number of GPUs detected, or std::nullopt if the device count cannot be queried
+ * @return
+ */
+ std::optional GetNumDevices() {
+ uint32_t numGpus = 0;
+ if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
+ return std::optional(numGpus);
+ } else {
+ return std::nullopt;
+ }
+ }
+}
+
+#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4dd41de0072866d818a3c901d813f0344a6af683
--- /dev/null
+++ b/backends/trtllm/lib/backend.cpp
@@ -0,0 +1,203 @@
+#include
+#include
+
+#include
+#include
+#include
+
+#include "backend.h"
+#include "hardware.h"
+
+
+void huggingface::tgi::backends::InitializeLogging() {
+#ifdef NDEBUG
+ if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
+ std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
+ std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
+ return std::tolower(c);
+ });
+
+ if (log_level == "debug")
+ spdlog::set_level(spdlog::level::debug);
+ else
+ spdlog::set_level(spdlog::level::info);
+ }
+#else
+ spdlog::set_level(spdlog::level::debug);
+#endif
+}
+
+void huggingface::tgi::backends::InitializeBackend() {
+ SPDLOG_INFO("Initializing Backend...");
+ nvmlInit_v2();
+ initTrtLlmPlugins();
+
+ InitializeLogging();
+
+ SPDLOG_INFO("Backend Executor Version: {}", tle::version());
+ const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
+ if (numGpus.has_value()) {
+ SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
+ } else {
+ SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
+ }
+}
+
+[[nodiscard]]
+tle::ParallelConfig
+huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
+ auto mode = tle::CommunicationMode::kLEADER;
+ std::optional orchestratorConfig = std::nullopt;
+
+ if (worldSize > 1) {
+ SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
+ mode = tle::CommunicationMode::kORCHESTRATOR;
+ orchestratorConfig = std::make_optional(true, workerPath, nullptr, true);
+ } else {
+ SPDLOG_INFO("Detected single engine deployment, using leader mode");
+ }
+
+ return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
+}
+
+[[nodiscard]]
+tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
+ tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
+
+ // Retrieve the compute capabilities to enable some options at runtime
+ const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
+
+ // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
+ const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
+ execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
+
+ // Define some configuration variables
+ execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
+ execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
+ execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
+ return execConfig;
+}
+
+tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
+ const uint32_t topK,
+ const float_t topP,
+ const float_t temperature,
+ const float_t repetition_penalty,
+ const float_t frequency_penalty,
+ const uint64_t seed) noexcept {
+
+ return tle::SamplingConfig(
+ 1, // TGI only use a single beam
+ topK,
+ topP,
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ seed,
+ temperature,
+ temperature,
+ std::nullopt,
+ repetition_penalty,
+ std::nullopt,
+ frequency_penalty
+ );
+}
+
+std::optional>>
+huggingface::tgi::backends::GetStopWordsFromConfig(
+ const std::filesystem::path &generationConfigPath) noexcept {
+ if (exists(generationConfigPath)) {
+ const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+ if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
+ SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+ std::list> stopWords(eosTokenIds.size());
+
+ const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+ return {tokenIdObj.template get()};
+ };
+
+ std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
+ return stopWords;
+ } else {
+ SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
+ }
+ } else {
+ SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+ }
+
+ return std::nullopt;
+}
+
+huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
+ const std::filesystem::path &enginesFolder,
+ const std::filesystem::path &executorWorker
+) :
+ config(json::parse(std::ifstream(enginesFolder / "config.json"))),
+ executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+ GetExecutorConfig(config, executorWorker.string())) {
+
+ SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get());
+
+ // Ensure we have enough GPUs on the system
+ const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
+ const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+ if (numGpus < worldSize) {
+ SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+ // todo : raise exception to catch on rust side
+ }
+
+ // Cache variables
+ maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
+
+ // Attempt to discover stopWords from the generation_config.json
+ const auto generationConfigPath = enginesFolder / "generation_config.json";
+ stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
+}
+
+[[nodiscard("Returned number of requests needs to be consumed")]]
+size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
+#ifdef NDEBUG
+ return executor.getNumResponsesReady();
+#else
+ const auto numResponses = executor.getNumResponsesReady();
+ if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
+ return numResponses;
+#endif
+}
+
+[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
+tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
+ const std::vector<TokenId> &tokens,
+ const uint32_t maxNewTokens,
+ const int32_t topK,
+ const float_t topP,
+ const float_t temperature,
+ const float_t repetitionPenalty,
+ const float_t frequencyPenalty,
+ const uint64_t seed
+) {
+ const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
+#ifndef NDEBUG
+ {
+ const auto &iterations = executor.getLatestIterationStats();
+ const auto &lastIteration = iterations.front();
+
+ SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
+ SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
+ SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
+ }
+#endif
+
+ const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
+
+ // Build the request
+ auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
+ request.setStopWords(stopWords);
+
+ // Submit to the executor for batching
+ return executor.enqueueRequest(request);
+}
+
+std::vector huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
+ return executor.awaitResponses();
+}
diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4c2dc26b6bfbdf900edfbc178467b4ffddfd3bae
--- /dev/null
+++ b/backends/trtllm/scripts/install_tensorrt.sh
@@ -0,0 +1,113 @@
+#!/bin/bash
+
+set -ex
+
+TRT_VER_BASE="10.4.0"
+TRT_VER_FULL="${TRT_VER_BASE}.26"
+CUDA_VER="12.6"
+CUDNN_VER="9.5.0.50-1"
+NCCL_VER="2.22.3-1+cuda12.6"
+CUBLAS_VER="12.6.3.3-1"
+NVRTC_VER="12.6.77-1"
+
+for i in "$@"; do
+ case $i in
+ --TRT_VER=?*) TRT_VER="${i#*=}";;
+ --CUDA_VER=?*) CUDA_VER="${i#*=}";;
+ --CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
+ --NCCL_VER=?*) NCCL_VER="${i#*=}";;
+ --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
+ *) ;;
+ esac
+ shift
+done
+
+NVCC_VERSION_OUTPUT=$(nvcc --version)
+if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
+ echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
+ exit 1
+fi
+
+install_ubuntu_requirements() {
+ apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
+ ARCH=$(uname -m)
+ if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
+ if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
+ curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
+ dpkg -i cuda-keyring_1.1-1_all.deb
+ rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
+
+ apt-get update
+ if [[ $(apt list --installed | grep libcudnn9) ]]; then
+ apt-get remove --purge -y --allow-change-held-packages libcudnn9*
+ fi
+ if [[ $(apt list --installed | grep libnccl) ]]; then
+ apt-get remove --purge -y --allow-change-held-packages libnccl*
+ fi
+ if [[ $(apt list --installed | grep libcublas) ]]; then
+ apt-get remove --purge -y --allow-change-held-packages libcublas*
+ fi
+ if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
+ apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
+ fi
+ CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
+ apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
+ apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
+ apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
+ # NVRTC static library doesn't exist in NGC PyTorch container.
+ NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
+ apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
+ apt-get clean
+ rm -rf /var/lib/apt/lists/*
+}
+
+install_centos_requirements() {
+ CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
+ yum -y update
+ yum -y install epel-release
+ yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
+ yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
+ yum clean all
+}
+
+install_tensorrt() {
+ #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
+ #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
+ TRT_CUDA_VERSION="12.6"
+
+ if [ -z "$RELEASE_URL_TRT" ];then
+ ARCH=${TRT_TARGETARCH}
+ if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
+ if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
+ if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
+ if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
+ if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
+ RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
+ fi
+ wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
+ tar -xf /tmp/TensorRT.tar -C /usr/local/
+ mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
+ # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
+ rm -rf /tmp/TensorRT.tar
+}
+
+# Install base packages depending on the base OS
+ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
+case "$ID" in
+ debian)
+ install_ubuntu_requirements
+ install_tensorrt
+ ;;
+ ubuntu)
+ install_ubuntu_requirements
+ install_tensorrt
+ ;;
+ centos)
+ install_centos_requirements
+ install_tensorrt
+ ;;
+ *)
+ echo "Unable to determine OS..."
+ exit 1
+ ;;
+esac
diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs
new file mode 100644
index 0000000000000000000000000000000000000000..812fd6e30d8b2ca2b0fad631c4fdf94d23e1b6b8
--- /dev/null
+++ b/backends/trtllm/src/errors.rs
@@ -0,0 +1,22 @@
+use std::path::PathBuf;
+use thiserror::Error;
+
+use text_generation_router::server;
+
+#[derive(Debug, Error)]
+pub enum TensorRtLlmBackendError {
+ #[error("Provided engine folder {0} doesn't exist")]
+ EngineFolderDoesntExists(PathBuf),
+ #[error("Provided executorWorker binary path {0} doesn't exist")]
+ ExecutorWorkerNotFound(PathBuf),
+ #[error("TensorRT-LLM Runtime error: {0}")]
+ Runtime(String),
+ #[error("Tokenizer error: {0}")]
+ Tokenizer(String),
+ #[error("Argument validation error: {0}")]
+ ArgumentValidation(String),
+ #[error("WebServer error: {0}")]
+ WebServer(#[from] server::WebServerError),
+ #[error("Tokio runtime failed to start: {0}")]
+ Tokio(#[from] std::io::Error),
+}
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0a92c050f653f8c294f76795bb193d7428bd2633
--- /dev/null
+++ b/backends/trtllm/src/ffi.cpp
@@ -0,0 +1,89 @@
+//
+// Created by mfuntowicz on 6/30/24.
+//
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "backends/trtllm/include/ffi.h"
+
+
+huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
+ const std::string_view &engineFolder,
+ const std::string_view &executorWorker
+) : TensorRtLlmBackend(engineFolder, executorWorker) {}
+
+
+uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
+ rust::Slice<const uint32_t> tokens,
+ uint32_t maxNewTokens,
+ int32_t topK,
+ float_t topP,
+ float_t temperature,
+ float_t repetition_penalty,
+ float_t frequency_penalty,
+ uint64_t seed) {
+
+ // This will copy all the items from the initial slice
+ std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
+ return TensorRtLlmBackend::Submit(
+ std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+}
+
+std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
+huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
+ const auto responses = TensorRtLlmBackend::PullNewTokens();
+
+ auto steps = std::make_unique<std::vector<GenerationStep>>();
+ steps->reserve(responses.size());
+
+#ifndef NDEBUG
+ SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
+#endif
+
+ // Transform tle::Response to GenerationStep
+ std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
+ const auto reqId = r.getRequestId();
+ if (!r.hasError()) {
+ const auto result = r.getResult();
+ return GenerationStep{
+ reqId,
+ static_cast<uint32_t>(result.outputTokenIds[0][0]),
+ result.logProbs.value()[0][0],
+ result.isFinal,
+ false,
+ std::string()
+ };
+ } else {
+ return GenerationStep{
+ reqId,
+ 0,
+ 0.0,
+ true,
+ true,
+ std::move(r.getErrorMsg())
+ };
+ }
+ });
+
+ return steps;
+}
+
+std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
+huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+ SPDLOG_INFO("Creating TensorRT-LLM Backend");
+ // Unconditionally call this to initialize and discover TRTLLM plugins
+ InitializeBackend();
+
+ const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
+ const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
+ return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
+}
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
new file mode 100644
index 0000000000000000000000000000000000000000..edd8caff154e94269ff921ef93cb4e721131a9ea
--- /dev/null
+++ b/backends/trtllm/src/lib.rs
@@ -0,0 +1,68 @@
+pub use looper::TensorRtLlmBackendV2;
+
+pub mod errors;
+mod looper;
+mod utils;
+
+#[cxx::bridge(namespace = "huggingface::tgi::backends")]
+mod ffi {
+ /// Struct used as shared type between rust and C++ to represent the result
+ /// of a single decoding iteration
+ #[derive(Debug, Clone)]
+ pub struct GenerationStep {
+ request_id: u64,
+ token_id: u32,
+ log_prob: f32,
+ is_final: bool,
+ has_error: bool,
+ error_msg: String,
+ }
+
+ unsafe extern "C++" {
+ include!("backends/trtllm/src/ffi.cpp");
+
+ /// Represent an instance of the underlying TensorRT-LLM backend
+ type TensorRtLlmBackendImpl;
+
+ /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
+ ///
+ /// # Arguments
+ ///
+ /// * `engine_folder`: Path to the folder containing all the TRTLLM engines
+ /// * `executor_worker`: Path to the TRTLLM executor worker
+ ///
+ /// returns: a `Result` wrapping a `UniquePtr<TensorRtLlmBackendImpl>` on success
+ ///
+ /// # Examples
+ ///
+ /// ```
+ ///
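+ /// // Hedged, illustrative sketch only: the engine and worker paths below are
+ /// // hypothetical placeholders, not defaults shipped with TGI.
+ /// // let backend = create_tensorrt_llm_backend(
+ /// //     "/data/trtllm-engines/llama-3-8b-instruct",
+ /// //     "/usr/local/bin/executorWorker",
+ /// // ).expect("failed to create the TensorRT-LLM backend");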
+ /// ```
+ #[rust_name = "create_tensorrt_llm_backend"]
+ fn CreateTensorRtLlmBackend(
+ engine_folder: &str,
+ executor_worker: &str,
+ ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
+
+ #[rust_name = "num_responses_ready"]
+ fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
+
+ #[rust_name = "submit"]
+ fn Submit(
+ self: Pin<&mut TensorRtLlmBackendImpl>,
+ tokens: &[u32],
+ max_new_tokens: u32,
+ top_k: i32,
+ top_p: f32,
+ temperature: f32,
+ repetition_penalty: f32,
+ frequency_penalty: f32,
+ seed: u64,
+ ) -> Result<u64>;
+
+ #[rust_name = "pull_tokens"]
+ fn PullTokens(
+ self: Pin<&mut TensorRtLlmBackendImpl>,
+ ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
+ }
+}
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e26155c163c910f46349adbe82838345d37289ce
--- /dev/null
+++ b/backends/trtllm/src/looper.rs
@@ -0,0 +1,382 @@
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
+
+use async_trait::async_trait;
+use cxx::UniquePtr;
+use hashbrown::HashMap;
+use tokenizers::Tokenizer;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::sync::TryAcquireError;
+use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{debug, error, warn};
+
+use text_generation_router::infer::InferError::{GenerationError, ValidationError};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidationError::{
+ EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
+};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
+
+use crate::errors::TensorRtLlmBackendError;
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
+use crate::utils::first_line;
+
+type InferResult<T> = Result<T, InferError>;
+
+/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
+struct GenerationContext {
+ request: ValidGenerateRequest,
+ start: Option,
+ queued: Instant,
+ streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+#[derive(Debug, Copy, Clone)]
+struct DecodedToken {
+ id: u32,
+ log_prob: f32,
+ is_final: bool,
+}
+
+impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
+ type Error = InferError;
+
+ fn try_from(step: &'step GenerationStep) -> Result {
+ if !step.has_error {
+ Ok(Self {
+ id: step.token_id,
+ log_prob: step.log_prob,
+ is_final: step.is_final,
+ })
+ } else {
+ Err(GenerationError(step.error_msg.clone()))
+ }
+ }
+}
+
+/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
+struct DecodedTokenContext {
+ token: DecodedToken,
+ start: Option,
+ queued: Instant,
+ channel: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+fn executor_status_looper(
+ mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+ max_inflight_requests: usize,
+ mut waiting_requests: UnboundedReceiver<GenerationContext>,
+ post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
+) {
+ // Track the tuple (request_id, stream) for each request
+ let mut in_flights =
+ HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
+
+ // TODO: Does it need a spin-loop?
+ 'scheduler: loop {
+ // Is there any request pending to be scheduled?
+ let awaiting_requests = waiting_requests.len();
+ for _ in 0..awaiting_requests {
+ // Retrieve all the requests
+ if let Some(mut ctx) = waiting_requests.blocking_recv() {
+ // Submit all the request to the executor and move the context to the in-flight tracker
+ let request = &ctx.request;
+ let generation_params = &request.parameters;
+ let stopping_params = &request.stopping_parameters;
+ let input_ids = request.input_ids.as_deref();
+
+ // Submit to the TensorRT-LLM executor for scheduling
+ match backend.pin_mut().submit(
+ &input_ids.unwrap(), // This is checked beforehand in validate()
+ stopping_params.max_new_tokens,
+ generation_params.top_k as i32,
+ generation_params.top_p,
+ generation_params.temperature,
+ generation_params.repetition_penalty,
+ generation_params.frequency_penalty,
+ generation_params.seed,
+ ) {
+ Ok(request_id) => {
+ // Insert the context linked to the generated request id in the tracker
+ debug!("[in-flight] Added {}", request_id);
+ ctx.start = Some(Instant::now());
+ in_flights.insert(request_id, ctx);
+ }
+ Err(e) => {
+ // Return to the caller
+ let what = e.to_string();
+ error!(error = what.as_str(), "Failed to schedule request");
+
+ let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
+ if let Err(_) = ctx.streamer.send(err) {
+ error!("Failed to send back error to the client");
+ }
+ }
+ };
+ }
+ }
+
+ if backend.num_responses_ready() > 0 {
+ match backend.pin_mut().pull_tokens() {
+ Ok(responses) => {
+ // Iterate through all the decoded token
+ for step in responses.deref() {
+ if let Some(ctx) = in_flights.get(&step.request_id) {
+ // Remove from tracked requests
+ let parcel =
+ DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+ token: dt,
+ start: ctx.start,
+ queued: ctx.queued,
+ channel: ctx.streamer.clone(),
+ });
+
+ // Submit the work to the post_processor
+ let posted = post_processor_sender.send((step.request_id, parcel));
+
+ if posted.is_err() || step.is_final {
+ debug!("Removing {}", step.request_id);
+ let _ = in_flights.remove(&step.request_id);
+ }
+ } else {
+ warn!("Untracked request {}", step.request_id,);
+ }
+ }
+ }
+ Err(ref err) => {
+ error!("Failed to get responses from the executor: {}.", err.what());
+ break 'scheduler;
+ }
+ }
+ }
+
+ // Hint the CPU we are spin-locking
+ hint::spin_loop();
+ }
+}
+
+fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
+ tokenizer: Tokenizer,
+ max_inflight_requests: usize,
+ mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
+) {
+ let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
+
+ 'post_processor: loop {
+ if decoded_tokens.is_closed() {
+ warn!("Post processor IPC is closed, loop will exit now.");
+ break 'post_processor;
+ }
+
+ if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
+ match decoded {
+ Ok(ctx) => {
+ states
+ .entry(request_id)
+ .and_modify(|s| s.push(*&ctx.token.id))
+ .or_insert_with(|| {
+ let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
+ state.push(*&ctx.token.id);
+ state
+ });
+
+ let out = match tokenizer.decode(&[ctx.token.id], false) {
+ Ok(text) => {
+ let is_special =
+ tokenizer.get_added_vocabulary().is_special_token(&text);
+ let token = Token {
+ id: ctx.token.id,
+ text,
+ logprob: ctx.token.log_prob,
+ special: is_special,
+ };
+
+ let out = if !ctx.token.is_final {
+ InferStreamResponse::Intermediate {
+ token,
+ top_tokens: vec![],
+ }
+ } else {
+ let tokens = states.remove(&request_id).unwrap();
+ let text = tokenizer.decode(&tokens, true);
+ let generated_text = GeneratedText {
+ text: text.unwrap(),
+ generated_tokens: tokens.len() as u32,
+ finish_reason: FinishReason::EndOfSequenceToken,
+ seed: None,
+ };
+
+ InferStreamResponse::End {
+ token,
+ top_tokens: vec![],
+ generated_text,
+ start: ctx.start.unwrap(),
+ queued: ctx.queued,
+ }
+ };
+
+ Ok(out)
+ }
+ Err(err) => Err(GenerationError(err.to_string())),
+ };
+
+ if let Err(_) = ctx.channel.send(out) {
+ warn!("Failed to send decoded token back to the user")
+ }
+ }
+ Err(_err) => {
+ todo!("what do we do?")
+ }
+ }
+ }
+ }
+}
+
+fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
+ engine_folder: P,
+ executor_worker_path: PP,
+) -> Result<(String, String), TensorRtLlmBackendError> {
+ // Retrieve paths as &str for the backend creation
+ let engine_folder = engine_folder.as_ref();
+ let executor_worker_path = executor_worker_path.as_ref();
+
+ // Ensure the engine folder exists
+ if !engine_folder.exists() {
+ let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
+
+ error!("Path validation failed: {}", err,);
+ return Err(err);
+ }
+
+ // Ensure executor worker binary exists
+ if !executor_worker_path.exists() {
+ let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(executor_worker_path.to_path_buf());
+
+ error!("Path validation failed: {}", err,);
+ return Err(err);
+ }
+
+ let engine_folder = String::from(
+ engine_folder
+ .to_str()
+ .expect("Failed to convert engine_folder to valid UTF-8"),
+ );
+
+ let executor_worker_path = String::from(
+ executor_worker_path
+ .to_str()
+ .expect("Failed to convert executor_worker_path to valid UTF-8"),
+ );
+
+ Ok((engine_folder, executor_worker_path))
+}
+
+unsafe impl Send for TensorRtLlmBackendImpl {}
+
+pub struct TensorRtLlmBackendV2 {
+ executor_looper: JoinHandle<()>,
+ post_processor_looper: JoinHandle<()>,
+ executor: UnboundedSender<GenerationContext>,
+}
+
+impl TensorRtLlmBackendV2 {
+ pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
+ tokenizer: Tokenizer,
+ engine_folder: P,
+ executor_worker_path: PP,
+ max_inflight_requests: usize,
+ ) -> Result<Self, TensorRtLlmBackendError> {
+ let (engine_folder, executor_worker_path) =
+ ensure_paths_exist(engine_folder, executor_worker_path)?;
+
+ // Allocate the IPC layer to communicate with the backend
+ let (executor_sender, executor_receiver) = unbounded_channel();
+ let (post_processor_sender, post_processor_receiver) = unbounded_channel();
+
+ // Create the FFI backend
+ let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
+ .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
+
+ // Executor looper is responsible for scheduling and pulling requests state at regular interval
+ let executor_looper = spawn_blocking(move || {
+ executor_status_looper(
+ backend,
+ max_inflight_requests,
+ executor_receiver,
+ post_processor_sender,
+ )
+ });
+
+ // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
+ let post_processor_looper = spawn_blocking(move || {
+ post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
+ });
+
+ Ok(TensorRtLlmBackendV2 {
+ executor_looper,
+ post_processor_looper,
+ executor: executor_sender,
+ })
+ }
+
+ fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
+ if request.input_ids.is_none() {
+ return Err(ValidationError(UnsupportedModality("No token provided")));
+ }
+
+ if request.top_n_tokens > 1 {
+ return Err(ValidationError(TopNTokensDisabled));
+ }
+
+ // TODO: Is it really needed? How can it be validated before?
+ if request.parameters.grammar.is_some() {
+ return Err(ValidationError(Grammar));
+ }
+
+ match request.inputs.len() {
+ 0 => Err(ValidationError(EmptyInput)),
+ 2.. => Err(GenerationError(
+ "TensorRT-LLM backend don't support multi-chunk".into(),
+ )),
+ 1 => match request.inputs.first().expect("Single item-chunk") {
+ Chunk::Text(_) => Ok(()),
+ Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
+ },
+ }
+ }
+}
+
+#[async_trait]
+impl Backend for TensorRtLlmBackendV2 {
+ fn schedule(
+ &self,
+ inner: ValidGenerateRequest,
+ ) -> Result<UnboundedReceiverStream<InferResult<InferStreamResponse>>, InferError> {
+ Self::validate(&inner)?;
+
+ // Open-up the stream to send tokens
+ let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
+
+ // Send the context to the executor for scheduling
+ let queued = Instant::now();
+ match self.executor.send(GenerationContext {
+ request: inner,
+ start: None,
+ queued,
+ streamer,
+ }) {
+ Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
+ Err(_) => Err(GenerationError(
+ "Failed to submit request to the backend".into(),
+ )),
+ }
+ }
+
+ async fn health(&self, _: bool) -> bool {
+ !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+ }
+}
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
new file mode 100644
index 0000000000000000000000000000000000000000..6a247fc1d5265b8b7b240e13169948e9be34eb0b
--- /dev/null
+++ b/backends/trtllm/src/main.rs
@@ -0,0 +1,302 @@
+use std::path::{Path, PathBuf};
+
+use clap::Parser;
+use hf_hub::api::tokio::{Api, ApiBuilder};
+use hf_hub::{Cache, Repo, RepoType};
+use tokenizers::Tokenizer;
+use tracing::info;
+
+use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
+use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::server::get_base_tokenizer;
+use text_generation_router::usage_stats::UsageStatsLevel;
+use text_generation_router::{server, HubTokenizerConfig};
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ #[clap(default_value = "128", long, env)]
+ max_concurrent_requests: usize,
+ #[clap(default_value = "2", long, env)]
+ max_best_of: usize,
+ #[clap(default_value = "4", long, env)]
+ max_stop_sequences: usize,
+ #[clap(default_value = "5", long, env)]
+ max_top_n_tokens: u32,
+ #[clap(default_value = "1024", long, env)]
+ max_input_tokens: usize,
+ #[clap(default_value = "2048", long, env)]
+ max_total_tokens: usize,
+ #[clap(default_value = "4096", long, env)]
+ max_batch_prefill_tokens: u32,
+ #[clap(long, env)]
+ max_batch_total_tokens: Option<u32>,
+ #[clap(default_value = "0.0.0.0", long, env)]
+ hostname: String,
+ #[clap(default_value = "3000", long, short, env)]
+ port: u16,
+ #[clap(long, env, required = true)]
+ tokenizer_name: String,
+ #[clap(long, env)]
+ tokenizer_config_path: Option<String>,
+ #[clap(long, env)]
+ revision: Option<String>,
+ #[clap(long, env)]
+ model_id: String,
+ #[clap(default_value = "2", long, env)]
+ validation_workers: usize,
+ #[clap(long, env)]
+ json_output: bool,
+ #[clap(long, env)]
+ otlp_endpoint: Option<String>,
+ #[clap(default_value = "text-generation-inference.router", long, env)]
+ otlp_service_name: String,
+ #[clap(long, env)]
+ cors_allow_origin: Option<Vec<String>>,
+ #[clap(default_value = "4", long, env)]
+ max_client_batch_size: usize,
+ #[clap(long, env)]
+ auth_token: Option<String>,
+ #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
+ executor_worker: PathBuf,
+ #[clap(default_value = "on", long, env)]
+ usage_stats: UsageStatsLevel,
+}
+
+async fn get_tokenizer(
+ tokenizer_name: &str,
+ tokenizer_config_path: Option<&str>,
+ revision: Option<&str>,
+) -> Option<Tokenizer> {
+ // Parse Huggingface hub token
+ let authorization_token = std::env::var("HF_TOKEN")
+ .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+ .ok();
+
+ // Tokenizer instance
+ let local_path = Path::new(tokenizer_name);
+
+ // Shared API builder initialization
+ let api_builder = || {
+ let mut builder = ApiBuilder::new()
+ .with_progress(false)
+ .with_token(authorization_token);
+
+ if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
+ builder = builder.with_cache_dir(cache_dir.into());
+ }
+
+ builder
+ };
+
+ // Decide if we need to use the API based on the revision and local path
+ let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
+
+ // Initialize API if needed
+ #[derive(Clone)]
+ enum Type {
+ Api(Api),
+ Cache(Cache),
+ None,
+ }
+ let api = if use_api {
+ if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
+ let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
+ .map_err(|_| ())
+ .map(|cache_dir| Cache::new(cache_dir.into()))
+ .unwrap_or_else(|_| Cache::default());
+ tracing::warn!("Offline mode active using cache defaults");
+ Type::Cache(cache)
+ } else {
+ tracing::info!("Using the Hugging Face API");
+ match api_builder().build() {
+ Ok(api) => Type::Api(api),
+ Err(_) => {
+ tracing::warn!("Unable to build the Hugging Face API");
+ Type::None
+ }
+ }
+ }
+ } else {
+ Type::None
+ };
+
+ // Load tokenizer and model info
+ let (
+ tokenizer_filename,
+ _config_filename,
+ tokenizer_config_filename,
+ _preprocessor_config_filename,
+ _processor_config_filename,
+ ) = match api {
+ Type::None => (
+ Some(local_path.join("tokenizer.json")),
+ Some(local_path.join("config.json")),
+ Some(local_path.join("tokenizer_config.json")),
+ Some(local_path.join("preprocessor_config.json")),
+ Some(local_path.join("processor_config.json")),
+ ),
+ Type::Api(api) => {
+ let api_repo = api.repo(Repo::with_revision(
+ tokenizer_name.to_string(),
+ RepoType::Model,
+ revision.unwrap_or_else(|| "main").to_string(),
+ ));
+
+ let tokenizer_filename = match api_repo.get("tokenizer.json").await {
+ Ok(tokenizer_filename) => Some(tokenizer_filename),
+ Err(_) => get_base_tokenizer(&api, &api_repo).await,
+ };
+ let config_filename = api_repo.get("config.json").await.ok();
+ let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+ let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
+ let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+
+ (
+ tokenizer_filename,
+ config_filename,
+ tokenizer_config_filename,
+ preprocessor_config_filename,
+ processor_config_filename,
+ )
+ }
+ Type::Cache(cache) => {
+ let repo = cache.repo(Repo::with_revision(
+ tokenizer_name.to_string(),
+ RepoType::Model,
+ revision.clone().unwrap_or_else(|| "main").to_string(),
+ ));
+ (
+ repo.get("tokenizer.json"),
+ repo.get("config.json"),
+ repo.get("tokenizer_config.json"),
+ repo.get("preprocessor_config.json"),
+ repo.get("processor_config.json"),
+ )
+ }
+ };
+
+ // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+ let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+ {
+ HubTokenizerConfig::from_file(filename)
+ } else {
+ tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+ };
+
+ tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+}
+
+#[tokio::main]
+async fn main() -> Result<(), TensorRtLlmBackendError> {
+ // Get args
+ let args = Args::parse();
+ // Pattern match configuration
+ let Args {
+ max_concurrent_requests,
+ max_best_of,
+ max_stop_sequences,
+ max_top_n_tokens,
+ max_input_tokens,
+ max_total_tokens,
+ max_batch_prefill_tokens,
+ max_batch_total_tokens,
+ hostname,
+ port,
+ tokenizer_name,
+ tokenizer_config_path,
+ revision,
+ model_id,
+ validation_workers,
+ json_output,
+ otlp_endpoint,
+ otlp_service_name,
+ cors_allow_origin,
+ max_client_batch_size,
+ auth_token,
+ executor_worker,
+ usage_stats,
+ } = args;
+
+ // Launch Tokio runtime
+ text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
+
+ // Validate args
+ if max_input_tokens >= max_total_tokens {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(
+ "`max_input_tokens` must be < `max_total_tokens`".to_string(),
+ ));
+ }
+ if max_input_tokens as u32 > max_batch_prefill_tokens {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+ }
+
+ if validation_workers == 0 {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(
+ "`validation_workers` must be > 0".to_string(),
+ ));
+ }
+
+ if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
+ if max_batch_prefill_tokens > *max_batch_total_tokens {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+ }
+ if max_total_tokens as u32 > *max_batch_total_tokens {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+ }
+ }
+
+ if !executor_worker.exists() {
+ return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
+ "`executor_work` specified path doesn't exists: {}",
+ executor_worker.display()
+ )));
+ }
+
+ // Create the backend
+ let tokenizer = get_tokenizer(
+ &tokenizer_name,
+ tokenizer_config_path.as_deref(),
+ revision.as_deref(),
+ )
+ .await
+ .expect("Failed to retrieve tokenizer implementation");
+
+ info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+ let backend = TensorRtLlmBackendV2::new(
+ tokenizer,
+ model_id,
+ executor_worker,
+ max_concurrent_requests,
+ )?;
+
+ info!("Successfully created backend");
+
+ // Run server
+ server::run(
+ backend,
+ max_concurrent_requests,
+ max_best_of,
+ max_stop_sequences,
+ max_top_n_tokens,
+ max_input_tokens,
+ max_total_tokens,
+ validation_workers,
+ auth_token,
+ tokenizer_name,
+ tokenizer_config_path,
+ revision,
+ hostname,
+ port,
+ cors_allow_origin,
+ false,
+ None,
+ None,
+ true,
+ max_client_batch_size,
+ usage_stats,
+ )
+ .await?;
+ Ok(())
+}
diff --git a/backends/trtllm/src/utils.rs b/backends/trtllm/src/utils.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4dedb0078632b43a5f3446fbb141a91cfc003867
--- /dev/null
+++ b/backends/trtllm/src/utils.rs
@@ -0,0 +1,22 @@
+///
+/// Extract the first line of the provided string reference.
+/// If there are no lines in the buffer, it returns a string
+/// whose content is defined by `fail`.
+/// # Arguments
+///
+/// * `s`: The string buffer to extract the first line from
+/// * `fail`: The string returned if no lines are present in `s`
+///
+/// returns: String
+///
+/// # Examples
+///
+/// ```
+/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
+/// first_line(s, "No line in string");
+/// ```
+#[inline]
+pub(crate) fn first_line(s: &str, fail: &str) -> String {
+ s.lines().next().unwrap_or(fail).to_string()
+}
diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8520065a759e22cf87908450d4a5f3a9526e7809
--- /dev/null
+++ b/backends/trtllm/tests/infer_test.cpp
@@ -0,0 +1,14 @@
+//
+// Created by mfuntowicz on 7/2/24.
+//
+#include <catch2/catch_all.hpp>
+#include <spdlog/spdlog.h>
+#include "../include/backend.h"
+
+TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
+ const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
+ const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
+
+ spdlog::info("Loading config from: {}", absolute(engines).string());
+ huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
+}
diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..4d32474e77f65d7fe0913e8177a1a13e739bdfc6
--- /dev/null
+++ b/backends/v2/Cargo.toml
@@ -0,0 +1,75 @@
+[package]
+name = "text-generation-router-v2"
+description = "Text Generation Webserver"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router-v2"
+path = "src/main.rs"
+
+[dependencies]
+async-trait = "0.1.74"
+async-stream = "0.3.5"
+axum = { version = "0.7", features = ["json"] }
+axum-tracing-opentelemetry = "0.16"
+text-generation-router = { path = "../../router" }
+clap = { version = "4.4.5", features = ["derive", "env"] }
+grpc-metadata = { path = "../grpc-metadata" }
+futures = "0.3.28"
+hf-hub = { workspace = true }
+jsonschema = { version = "0.17.1", features = ["draft202012"] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
+nohash-hasher = "0.2.0"
+opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
+opentelemetry-otlp = "0.13.0"
+rand = "0.8.5"
+reqwest = { version = "0.11.20", features = [] }
+serde = "1.0.188"
+serde_json = "1.0.107"
+slotmap = "1.0.7"
+thiserror = "1.0.48"
+tokenizers = { workspace = true }
+tokio = { version = "1.32.0", features = [
+ "rt",
+ "rt-multi-thread",
+ "parking_lot",
+ "signal",
+ "sync",
+] }
+tokio-stream = "0.1.14"
+tower-http = { version = "0.5.1", features = ["cors"] }
+tracing = "0.1.37"
+tracing-opentelemetry = "0.21.0"
+tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
+utoipa = { version = "4.2.0", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
+init-tracing-opentelemetry = { version = "0.14.1", features = [
+ "opentelemetry-otlp",
+] }
+minijinja = { workspace = true }
+minijinja-contrib = { workspace = true }
+futures-util = "0.3.30"
+regex = "1.10.3"
+once_cell = "1.19.0"
+image = "0.25.1"
+base64 = { workspace = true }
+prost = "^0.12"
+tonic = "^0.10"
+tower = "^0.4"
+
+[build-dependencies]
+tonic-build = "0.10.1"
+prost-build = "0.12.1"
+
+[features]
+default = ["ngrok"]
+ngrok = ["text-generation-router/ngrok"]
+google = ["text-generation-router/google"]
+kserve = ["text-generation-router/kserve"]
diff --git a/backends/v2/build.rs b/backends/v2/build.rs
new file mode 100644
index 0000000000000000000000000000000000000000..f1d85dc70ed09d9d47f5890da215b841834e6824
--- /dev/null
+++ b/backends/v2/build.rs
@@ -0,0 +1,19 @@
+use std::fs;
+
+fn main() -> Result<(), Box> {
+ println!("cargo:rerun-if-changed=../../proto/");
+
+ fs::create_dir_all("src/client/pb").unwrap_or(());
+ let mut config = prost_build::Config::new();
+ config.protoc_arg("--experimental_allow_proto3_optional");
+
+ tonic_build::configure()
+ .build_client(true)
+ .build_server(false)
+ .out_dir("src/client/pb")
+ .include_file("mod.rs")
+ .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
+ .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
+ Ok(())
+}
diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
new file mode 100644
index 0000000000000000000000000000000000000000..bc264138d28636abda4e5fe02ec956abf14fd8ba
--- /dev/null
+++ b/backends/v2/src/backend.rs
@@ -0,0 +1,502 @@
+use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
+/// Batching and inference logic
+use crate::queue::{Entry, Queue};
+use async_trait::async_trait;
+use nohash_hasher::IntMap;
+use std::sync::Arc;
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::{FinishReason, PrefillToken, Token};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{info_span, instrument, Instrument, Span};
+
+pub struct BackendV2 {
+ /// Request queue
+ queue: Queue,
+ /// Notify batcher on queue appends
+ batching_task_notifier: Arc,
+ /// Client clone, used for health checks to skip the queue
+ client: ShardedClient,
+}
+
+impl BackendV2 {
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn new(
+ client: ShardedClient,
+ waiting_served_ratio: f32,
+ max_batch_prefill_tokens: u32,
+ max_batch_total_tokens: u32,
+ max_waiting_tokens: usize,
+ max_batch_size: Option<usize>,
+ requires_padding: bool,
+ window_size: Option<u32>,
+ speculate: u32,
+ ) -> Self {
+ // Infer shared state
+ let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+ let block_size = match attention.as_str() {
+ "flashinfer" => 1,
+ "flashdecoding" => 256,
+ "paged" => 16,
+ _ => unreachable!(),
+ };
+
+ let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+ let batching_task_notifier = Arc::new(Notify::new());
+
+ // Spawn batching background task that contains all the inference logic
+ tokio::spawn(batching_task(
+ client.clone(),
+ waiting_served_ratio,
+ max_batch_prefill_tokens,
+ max_batch_total_tokens,
+ max_waiting_tokens,
+ max_batch_size,
+ queue.clone(),
+ batching_task_notifier.clone(),
+ ));
+
+ Self {
+ queue,
+ batching_task_notifier,
+ client,
+ }
+ }
+}
+
+#[async_trait]
+impl Backend for BackendV2 {
+ #[instrument(skip_all)]
+ fn schedule(
+ &self,
+ request: ValidGenerateRequest,
+ ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+ // MPSC channel to communicate with the background batching task
+ let (response_tx, response_rx) = mpsc::unbounded_channel();
+
+ // Append the request to the queue
+ self.queue.append(Entry {
+ request,
+ response_tx,
+ span: Span::current(),
+ temp_span: None,
+ queue_time: Instant::now(),
+ batch_time: None,
+ });
+
+ // Notify the background task that we have a new entry in the queue that needs
+ // to be batched
+ self.batching_task_notifier.notify_one();
+
+ // Return stream
+ Ok(UnboundedReceiverStream::new(response_rx))
+ }
+
+ async fn health(&self, current_health: bool) -> bool {
+ if current_health {
+ // Generation is healthy, we only check that the shards can allocate on device
+ self.client.device_health().await
+ } else {
+ self.client.model_health().await
+ }
+ .is_ok()
+ }
+}
+
+/// Batching logic
+/// Will be launched in a background Tokio task
+///
+/// Batches requests and sends them to the inference server
+#[allow(clippy::too_many_arguments)]
+pub(crate) async fn batching_task(
+ mut client: ShardedClient,
+ waiting_served_ratio: f32,
+ max_batch_prefill_tokens: u32,
+ max_batch_total_tokens: u32,
+ max_waiting_tokens: usize,
+ max_batch_size: Option<usize>,
+ queue: Queue,
+ notifier: Arc,
+) {
+ // Infinite loop
+ loop {
+ // Wait for a notification from the Infer struct
+ notifier.notified().await;
+
+ // Get the next batch from the queue
+ // This batch might be smaller than the maximum batch size if there are not enough requests
+ // waiting in the queue
+ while let Some((mut entries, batch, span)) = queue
+ .next_batch(
+ None,
+ max_batch_size,
+ max_batch_prefill_tokens,
+ max_batch_total_tokens,
+ )
+ .await
+ {
+ let mut cached_batch = prefill(&mut client, batch, &mut entries)
+ .instrument(span)
+ .await;
+ let mut waiting_tokens = 1;
+
+ // We loop until we do not receive any cached batch from the inference server (== until
+ // all requests have met their stopping criteria)
+ while let Some(batch) = cached_batch {
+ // Get current batch info
+ let batch_size = batch.size;
+ let batch_max_tokens = batch.max_tokens;
+ let mut batches = vec![batch];
+ metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+ metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+
+ let min_size = if waiting_tokens >= max_waiting_tokens {
+ // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+ // to add a new batch even though its size might be small
+ None
+ } else {
+ // Minimum batch size
+ Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+ };
+
+ let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+ let max_size =
+ max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+ // Try to get a new batch
+ if let Some((mut new_entries, new_batch, span)) = queue
+ .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+ .await
+ {
+ // Tracking metrics
+ if min_size.is_some() {
+ metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+ .increment(1);
+ } else {
+ metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+ .increment(1);
+ }
+
+ entries.iter_mut().for_each(|(_, entry)| {
+ // Create a new span to add the info that this entry is waiting
+ // because a new batch is being computed
+ let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+ // Add relationships
+ span.follows_from(&entry_waiting_span);
+ entry_waiting_span.follows_from(&span);
+ // Update entry
+ entry.temp_span = Some(entry_waiting_span);
+ });
+
+ // Generate one token for this new batch to have the attention past in cache
+ let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+ .instrument(span)
+ .await;
+ // Reset waiting counter
+ waiting_tokens = 1;
+ // Extend current batch with the new batch
+ if let Some(new_cached_batch) = new_cached_batch {
+ entries.extend(new_entries);
+ batches.push(new_cached_batch);
+ }
+ }
+
+ // Create span for this batch to add context to inference calls
+ let next_batch_size = entries.len();
+ let next_batch_span =
+ info_span!(parent: None, "batch", batch_size = next_batch_size);
+ entries.iter_mut().for_each(|(_, entry)| {
+ // Create a new span to link the batch back to this entry
+ let entry_batch_span = info_span!(parent: &entry.span, "infer");
+ // Add relationships
+ next_batch_span.follows_from(&entry_batch_span);
+ entry_batch_span.follows_from(&next_batch_span);
+ // Update entry
+ entry.temp_span = Some(entry_batch_span);
+ });
+
+ cached_batch = decode(&mut client, batches, &mut entries)
+ .instrument(next_batch_span)
+ .await;
+ waiting_tokens += 1;
+ }
+ metrics::gauge!("tgi_batch_current_size").set(0.0);
+ metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
+ }
+ }
+}
+
+#[instrument(skip_all)]
+async fn prefill(
+ client: &mut ShardedClient,
+ batch: Batch,
+ entries: &mut IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+ let start_time = Instant::now();
+ let batch_id = batch.id;
+ metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
+
+ match client.prefill(batch).await {
+ Ok((generations, next_batch, timings)) => {
+ let start_filtering_time = Instant::now();
+ // Send generated tokens and filter stopped entries
+ filter_send_generations(generations, entries);
+
+ // Filter next batch and remove requests that were stopped
+ let next_batch = filter_batch(client, next_batch, entries).await;
+
+ metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
+ .record(timings.forward.as_secs_f64());
+ metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+ .record(timings.decode.as_secs_f64());
+ metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+ .record(start_filtering_time.elapsed().as_secs_f64());
+ metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
+ .record(start_time.elapsed().as_secs_f64());
+ metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
+ next_batch
+ }
+ // If we have an error, we discard the whole batch
+ Err(err) => {
+ let _ = client.clear_cache(Some(batch_id)).await;
+ send_errors(err, entries);
+ metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
+ None
+ }
+ }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+ client: &mut ShardedClient,
+ batches: Vec<CachedBatch>,
+ entries: &mut IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+ let start_time = Instant::now();
+ let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
+ metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
+
+ match client.decode(batches).await {
+ Ok((generations, next_batch, timings)) => {
+ let start_filtering_time = Instant::now();
+ // Send generated tokens and filter stopped entries
+ filter_send_generations(generations, entries);
+
+ // Filter next batch and remove requests that were stopped
+ let next_batch = filter_batch(client, next_batch, entries).await;
+
+ if let Some(concat_duration) = timings.concat {
+ metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+ .record(concat_duration.as_secs_f64());
+ }
+ metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+ .record(timings.forward.as_secs_f64());
+ metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+ .record(timings.decode.as_secs_f64());
+ metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+ .record(start_filtering_time.elapsed().as_secs_f64());
+ metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+ .record(start_time.elapsed().as_secs_f64());
+ metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
+ next_batch
+ }
+ // If we have an error, we discard the whole batch
+ Err(err) => {
+ for id in batch_ids {
+ let _ = client.clear_cache(Some(id)).await;
+ }
+ send_errors(err, entries);
+ metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
+ None
+ }
+ }
+}
+
+/// Filter a `batch` and remove all requests not present in `entries`
+#[instrument(skip_all)]
+async fn filter_batch(
+ client: &mut ShardedClient,
+ next_batch: Option<CachedBatch>,
+ entries: &IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+ let mut batch = next_batch?;
+
+ // No need to filter
+ if batch.size as usize == entries.len() {
+ return Some(batch);
+ }
+
+ let id = batch.id;
+
+ // Retain only requests that are still in entries
+ batch.request_ids.retain(|id| entries.contains_key(id));
+
+ if batch.request_ids.is_empty() {
+ // All requests have been filtered out
+ // Next batch is now empty
+ // Clear it from the Python shards cache
+ // We unwrap here as we need to panic since we cannot recover if this method fails
+ client.clear_cache(Some(id)).await.unwrap();
+ None
+ } else {
+ // Filter Python shard cache
+ // We unwrap here as we need to panic since we cannot recover if this method fails
+ client.filter_batch(id, batch.request_ids).await.unwrap()
+ }
+}
+
+/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// and filter entries
+#[instrument(skip_all)]
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+ generations.into_iter().for_each(|generation| {
+ let id = generation.request_id;
+ // Get entry
+ // We can `expect` here as the request id should always be in the entries
+ let entry = entries
+ .get(&id)
+ .expect("ID not found in entries. This is a bug.");
+
+ // Create and enter a span to link this function back to the entry
+ let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+ // Send generation responses back to the infer task
+ // If we receive an error from the channel, it means that the client dropped the
+ // request and we need to stop generating, hence the unwrap_or(true)
+ let stopped = send_responses(generation, entry).inspect_err(|_err| {
+ tracing::error!("Entry response channel error.");
+ metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+ }).unwrap_or(true);
+ if stopped {
+ entries.remove(&id).expect("ID not found in entries. This is a bug.");
+ }
+ });
+}
+
+/// Send responses through the `entry` response channel
+fn send_responses(
+ generation: Generation,
+ entry: &Entry,
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
+ // Return directly if the channel is disconnected
+ if entry.response_tx.is_closed() {
+ metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
+ return Ok(true);
+ }
+
+ let mut stopped = false;
+
+ if let Some(prefill_tokens) = generation.prefill_tokens {
+ // Create Token objects
+ // We do that here instead of in the Python code as Rust for loops are faster
+ let prefill_tokens = prefill_tokens
+ .ids
+ .into_iter()
+ .zip(prefill_tokens.logprobs)
+ .zip(prefill_tokens.texts)
+ .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+ .collect();
+
+ // Send message
+ entry
+ .response_tx
+ .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+ }
+
+ // Create last Token
+ let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+ let n = tokens_.ids.len();
+ metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
+ let mut iterator = tokens_
+ .ids
+ .into_iter()
+ .zip(tokens_.logprobs)
+ .zip(tokens_.texts)
+ .zip(tokens_.is_special)
+ .enumerate()
+ .peekable();
+ while let Some((i, (((id, logprob), text), special))) = iterator.next() {
+ let token = Token {
+ id,
+ text,
+ logprob,
+ special,
+ };
+ let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
+ top_tokens_
+ .ids
+ .iter()
+ .zip(top_tokens_.logprobs.iter())
+ .zip(top_tokens_.texts.iter())
+ .zip(top_tokens_.is_special.iter())
+ .map(|(((&id, &logprob), text), &special)| Token {
+ id,
+ text: text.to_string(),
+ logprob,
+ special,
+ })
+ .collect()
+ } else {
+ vec![]
+ };
+ match (&generation.generated_text, iterator.peek()) {
+ (Some(generated_text), None) => {
+ // Generation has ended
+ stopped = true;
+ // Send message
+ entry.response_tx.send(Ok(InferStreamResponse::End {
+ token,
+ top_tokens,
+ generated_text: GeneratedText::from(generated_text.clone()),
+ queued: entry.queue_time,
+ start: entry.batch_time.unwrap(),
+ }))?;
+ }
+ _ => {
+ // Send message
+ entry
+ .response_tx
+ .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+ }
+ }
+ }
+
+ Ok(stopped)
+}
+
+/// Send errors to Infer for all `entries`
+#[instrument(skip_all)]
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+ entries.drain().for_each(|(_, entry)| {
+ // Create and enter a span to link this function back to the entry
+ let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+ let err = InferError::GenerationError(error.to_string());
+ metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
+ tracing::error!("{err}");
+
+ // unwrap_or is valid here as we don't care if the receiver is gone.
+ entry
+ .response_tx
+ .send(Err(err))
+ .unwrap_or(());
+ });
+}
+
+impl From for GeneratedText {
+ fn from(value: crate::client::GeneratedText) -> Self {
+ let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
+ let finish_reason = match v2_finish_reason {
+ crate::client::FinishReason::Length => FinishReason::Length,
+ crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+ crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
+ };
+
+ Self {
+ text: value.text,
+ generated_tokens: value.generated_tokens,
+ finish_reason,
+ seed: value.seed,
+ }
+ }
+}
diff --git a/backends/v2/src/client/grpc_client.rs b/backends/v2/src/client/grpc_client.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b494352141cc1575854f5926ed70c17d6275dd45
--- /dev/null
+++ b/backends/v2/src/client/grpc_client.rs
@@ -0,0 +1,257 @@
+/// Single shard Client
+use crate::client::pb;
+use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
+use grpc_metadata::InjectTelemetryContext;
+use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use pb::generate::v2::*;
+use std::cmp::min;
+use std::time::Duration;
+use tonic::transport::{Channel, Uri};
+use tracing::instrument;
+
+/// Text Generation Inference gRPC client
+#[derive(Debug, Clone)]
+pub struct Client {
+ stub: TextGenerationServiceClient<Channel>,
+}
+
+impl Client {
+ /// Returns a client connected to the given url
+ #[allow(dead_code)]
+ pub async fn connect(uri: Uri) -> Result<Self> {
+ let channel = Channel::builder(uri).connect().await?;
+
+ Ok(Self {
+ stub: TextGenerationServiceClient::new(channel),
+ })
+ }
+
+ /// Returns a client connected to the given unix socket
+ pub async fn connect_uds(path: String) -> Result<Self> {
+ let channel = Channel::from_shared("http://[::]:50051".to_string())
+ .unwrap()
+ .connect_with_connector(tower::service_fn(move |_: Uri| {
+ tokio::net::UnixStream::connect(path.clone())
+ }))
+ .await?;
+
+ Ok(Self {
+ stub: TextGenerationServiceClient::new(channel),
+ })
+ }
+
+ /// Returns a list of uris or unix sockets of all shards
+ #[instrument(skip(self))]
+ pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
+ let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
+ let response = self.stub.service_discovery(request).await.map_err(|_| {
+ ClientError::Connection("Server does not support v2 interface".to_string())
+ })?;
+ let urls = response
+ .into_inner()
+ .urls
+ .into_iter()
+ // Remove unix socket prefix
+ .map(|url| match url.strip_prefix("unix://") {
+ None => url,
+ Some(stripped_url) => stripped_url.to_string(),
+ })
+ .collect();
+ Ok(urls)
+ }
+
+ /// Get model info
+ #[instrument(skip(self))]
+ pub async fn info(&mut self) -> Result<InfoResponse> {
+ let request = tonic::Request::new(InfoRequest {}).inject_context();
+ let response = self.stub.info(request).await?.into_inner();
+ Ok(response)
+ }
+
+ /// Get model health
+ #[instrument(skip(self))]
+ pub async fn health(&mut self) -> Result<HealthResponse> {
+ let request = tonic::Request::new(HealthRequest {}).inject_context();
+ let response = self.stub.health(request).await?.into_inner();
+ Ok(response)
+ }
+
+ /// Clear the past generations cache
+ #[instrument(skip(self))]
+ pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
+ let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
+ self.stub.clear_cache(request).await?;
+ Ok(())
+ }
+
+ /// Filter a cached batch
+ #[instrument(skip(self))]
+ pub async fn filter_batch(
+ &mut self,
+ batch_id: u64,
+ request_ids: Vec<u64>,
+ ) -> Result