# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

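# cargo-chef gives dependency layer caching: the `planner` stage computes a
# dependency recipe from the manifests, and the `builder` stage "cooks"
# (pre-builds) those dependencies, so source-only changes don't invalidate
# the slow dependency-compilation layer.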
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
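
# protoc is needed at build time to compile the gRPC/protobuf definitions in `proto/`.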
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
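# --frozen asserts Cargo.lock is unchanged and disables network access, so the
# build reuses only the dependencies pre-built by `cargo chef cook` above.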
RUN cargo build --profile release-opt --frozen

# Text Generation Inference base image for ROCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git \
    make \
    libmsgpack-dev \
    libssl-dev \
    llvm-dev \
    g++ \
    # Needed to build vLLM & flash-attention.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev \
    hipcub-dev \
    rocblas-dev \
    hiprand-dev \
    hipfft-dev \
    rocrand-dev \
    miopen-hip-dev \
    hipsolver-dev \
    rccl-dev \
    cmake \
    python3.11-venv && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
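# gfx90a targets MI200-series GPUs and gfx942 targets MI300-series; narrow the
# list at build time if you only need one family, e.g. (illustrative):
#   docker build --build-arg PYTORCH_ROCM_ARCH="gfx942" -f Dockerfile_amd .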

# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for Miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
         *)              MAMBA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  exit 1 ;; \
         *)              /opt/conda/bin/conda update -y conda &&  \
                         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*

RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/


ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}


# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
    && cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
    && cd build/release \
    && make package

FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
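
# Pattern used throughout this file: each component is built in its own stage,
# its artifacts are exported through a minimal `scratch` stage, and later
# stages consume them via `RUN --mount=type=bind,from=...` without adding layers.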

# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
    && cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /

# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
    && cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /

# AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /


FROM base AS build_pytorch

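# If the hipBLASLt stage produced .deb packages, install them, then strip the
# stock hipblaslt entries from other packages' Depends lines in dpkg's status
# database so the locally built packages don't raise dependency conflicts.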
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"

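# tools/amd_build/build_amd.py HIPifies the CUDA sources so the same tree
# builds against ROCm; the resulting wheel is exported via `scratch` below.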
RUN git clone ${PYTORCH_REPO} pytorch \
    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
    && pip install -r requirements.txt --no-cache-dir  \
    && python tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /

FROM base AS install_deps

ARG COMMON_WORKDIR

# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        # RCCL needs to be installed twice
        && dpkg -i /install/*.deb \
        && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
        && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_triton,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y triton \
        && pip install /install/*.whl; \
    fi

RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
    # Preemptively uninstall to prevent pip same-version no-installs
    pip uninstall -y amdsmi \
    && pip install /install/*.whl

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y torch torchvision \
        && pip install /install/*.whl; \
    fi

FROM install_deps AS kernel-builder
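# Every kernel builder below inherits this image and therefore compiles
# against the locally built torch / triton / ROCm stack installed above.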

# Build vLLM kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

RUN python setup.py build

# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

RUN python setup.py build

FROM install_deps AS base-copy

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
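
# HF_HOME=/data sets the Hugging Face cache root to /data; mount a volume
# there to persist downloaded weights, e.g. (illustrative): -v $PWD/data:/data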

# Copy build artifacts from vLLM builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Install server
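# `make gen-server` generates the Python gRPC stubs from `proto/` before the
# server package and its extras are installed.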
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_rocm.txt && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
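
# Build this variant by targeting the stage explicitly, e.g. (illustrative tag):
#   docker build --target sagemaker -t tgi-rocm:sagemaker -f Dockerfile_amd .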

# Final image
FROM base-copy

# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# On MI250 and MI300, flash attention performance with Triton FA is slightly better than with CK.
# However, Triton requires tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
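
# Example launch (illustrative; image name, model id and volume are placeholders):
#   docker run --rm -it --device=/dev/kfd --device=/dev/dri --shm-size 1g \
#     -p 8080:80 -v $PWD/data:/data tgi-rocm --model-id <model-id>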