# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
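# (The sparse protocol fetches only the crates.io index entries that are needed
# instead of cloning the full git index, which speeds up dependency resolution.)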

FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher

RUN cargo chef prepare --recipe-path recipe.json
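# The recipe is a dependency-only description of the workspace; the cook step
# in the builder stage compiles just those dependencies, so the Docker layer
# cache survives source changes and only invalidates when the manifests change.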

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP
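# (protoc is needed to generate the gRPC stubs for the definitions under
# proto/ when the Rust crates' build scripts run.)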

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
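# Dependencies are now compiled and cached; the COPY + cargo build below only
# rebuilds the workspace crates themselves.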

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
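# (--frozen = --locked + --offline: fail instead of updating Cargo.lock or
# touching the network, so the dependency set cooked above is reused as-is.)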

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install

# NOTE: When updating the PyTorch version, remember to remove (or update) the pinned `pip install nvidia-nccl-cu12==2.22.3` below. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0

ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM

ENV PATH=/opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        curl \
        git && \
        rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
         *)              MAMBA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh
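# (-b: non-interactive batch install, -p: install prefix)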

# Install PyTorch
# This build path only supports CUDA on x86_64, so we exit with an error on arm64
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  exit 1 ;; \
         *)              /opt/conda/bin/conda update -y conda &&  \
                         /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)"  ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# CUDA kernels builder image
FROM pytorch-install AS kernel-builder

ARG MAX_JOBS=8
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
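# (8.0 = A100, 8.6 = consumer/workstation Ampere, 9.0 = H100; +PTX embeds PTX
# so newer, unlisted architectures can still JIT-compile the kernels.)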

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ninja-build cmake \
        && rm -rf /var/lib/apt/lists/*

# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

RUN python setup.py build

# Build Transformers exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-exllamav2 Makefile

# Build specific version of exllamav2
RUN make build-exllamav2

# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of awq
RUN make build-awq

# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of eetq
RUN make build-eetq

# Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of lorax punica kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica

# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build the custom kernels
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder

WORKDIR /usr/src

ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
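# (Wider list than the kernel-builder default above: also covers V100 (7.0),
# Turing/T4 (7.5) and Ada (8.9).)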

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-cuda

# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all

# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN make install-flashinfer

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base

# Conda env
ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
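# (HF_HOME relocates the Hugging Face cache to /data, usually a mounted volume;
# HF_HUB_ENABLE_HF_TRANSFER enables the Rust-based hf_transfer downloader for
# faster model pulls; PORT is the port the server listens on by default.)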

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        libssl-dev \
        ca-certificates \
        make \
        curl \
        git \
        && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
    pip install nvidia-nccl-cu12==2.22.3
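# (Pinned NCCL that overrides the version PyTorch pulls in; see the NOTE next
# to PYTORCH_VERSION above for context.)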

ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the Rust binaries
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn,
# which fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1

# Deps before the binaries:
# the binaries change on every build (we burn the GIT_SHA into them),
# while the deps change less often.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        g++ \
        && rm -rf /var/lib/apt/lists/*

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher


# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
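# (The entrypoint script is expected to map SageMaker environment variables,
# e.g. HF_MODEL_ID, onto text-generation-launcher flags; script not shown here.)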

# Final image
FROM base

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
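
# Example usage (image tag, port and model id are illustrative, not prescriptive):
#   docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD/data:/data \
#     tgi:local --model-id <model-id>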