Dockerfile_amd 10.9 KB
Newer Older
fxmarty's avatar
fxmarty committed
1
# Rust builder
2
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
fxmarty's avatar
fxmarty committed
3
4
5
6
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

7
FROM chef AS planner
ur4t's avatar
ur4t committed
8
COPY Cargo.lock Cargo.lock
fxmarty's avatar
fxmarty committed
9
10
11
12
13
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
Nicolas Patry's avatar
Nicolas Patry committed
14
COPY backends backends
fxmarty's avatar
fxmarty committed
15
16
17
18
19
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

Nicolas Patry's avatar
Nicolas Patry committed
20
21
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
fxmarty's avatar
fxmarty committed
22
23
24
25
26
27
28
# Install the protobuf compiler (protoc 21.12) and the include files shipped in
# the same release archive into /usr/local.
# `curl --fail` aborts the build at the download step on an HTTP error instead
# of saving the error page and failing later inside `unzip`.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -fOL "https://github.com/protocolbuffers/protobuf/releases/download/v21.12/${PROTOC_ZIP}" && \
    unzip -o "${PROTOC_ZIP}" -d /usr/local bin/protoc && \
    unzip -o "${PROTOC_ZIP}" -d /usr/local 'include/*' && \
    rm -f "${PROTOC_ZIP}"

COPY --from=planner /usr/src/recipe.json recipe.json
29
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
fxmarty's avatar
fxmarty committed
30

Nicolas Patry's avatar
Nicolas Patry committed
31
32
33
ARG GIT_SHA
ARG DOCKER_LABEL

34
COPY Cargo.lock Cargo.lock
fxmarty's avatar
fxmarty committed
35
36
37
38
39
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
Nicolas Patry's avatar
Nicolas Patry committed
40
COPY backends backends
fxmarty's avatar
fxmarty committed
41
COPY launcher launcher
42
RUN cargo build --profile release-opt --frozen
fxmarty's avatar
fxmarty committed
43
44

# Text Generation Inference base image for ROCm
45
FROM rocm/dev-ubuntu-22.04:6.2 AS base
fxmarty's avatar
fxmarty committed
46
47
48
49
50
51
52
53

# System packages for the ROCm base image. The apt package lists are removed in
# the same layer so they do not end up in the image.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        # General toolchain, build tools and development libraries.
        build-essential \
        ca-certificates \
        ccache \
        cmake \
        curl \
        g++ \
        git \
        libmsgpack-dev \
        libssl-dev \
        llvm-dev \
        make \
        python3.11-venv \
        # ROCm/HIP development packages; needed to build VLLM & flash.
        hipblas-dev \
        hipcub-dev \
        hipfft-dev \
        hiprand-dev \
        hipsolver-dev \
        hipsparse-dev \
        miopen-hip-dev \
        rccl-dev \
        rocblas-dev \
        rocrand-dev \
        rocthrust-dev \
    && rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
Nicolas Patry's avatar
Nicolas Patry committed
76
ARG PYTHON_VERSION='3.11.10'
fxmarty's avatar
fxmarty committed
77
78
# Automatically set by buildx
ARG TARGETPLATFORM
79
80
81
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
fxmarty's avatar
fxmarty committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95

# Install Mambaforge (conda + mamba) into /opt/conda.
# NOTE(review): the earlier note here said Ubuntu 22.04 could not be used
# (TGI needing libssl.so.1.1 rather than libssl.so.3) and that Ubuntu 20.04's
# Python 3.8 was too old, hence miniconda. The `base` stage is nevertheless
# built on rocm/dev-ubuntu-22.04 - confirm whether that rationale still holds.
#
# Docker's TARGETPLATFORM is translated into the architecture suffix used by
# the Mambaforge release assets. Download, install and cleanup share a single
# RUN so the installer script is never stored in an image layer.
RUN case "${TARGETPLATFORM}" in \
         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
         *)              MAMBA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
    chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

Nicolas Patry's avatar
Nicolas Patry committed
96
97
98
99
100
101
102
103
104
# RUN conda install intel::mkl-static intel::mkl-include
# Update conda and install the pinned Python interpreter (PYTHON_VERSION) into
# /opt/conda, then clean the conda caches in the same layer.
# (PyTorch itself is built from source in the `build_pytorch` stage.)
# linux/arm64 is rejected here: print the reason before failing the build.
RUN case "${TARGETPLATFORM}" in \
         "linux/arm64")  echo "linux/arm64 is not supported by this image" >&2; exit 1 ;; \
         *)              /opt/conda/bin/conda update -y conda &&  \
                         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya
fxmarty's avatar
fxmarty committed
105

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Install flash-attention, torch dependencies
# `--no-cache-dir` on both pip calls keeps pip's download cache out of the layer.
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir numpy einops ninja joblib msgpack cmake

# Install MKL from the 2021 series. `-y` makes the install explicitly
# non-interactive; the conda caches are cleaned in the same layer.
RUN conda install -y mkl=2021 && \
    conda clean -ya
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/


ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}


# hipBLASLt: check out the pinned ref, run its install script for the target
# GPU architectures, then produce packages with `make package`.
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git && \
    cd hipBLASLt && \
    git checkout ${HIPBLASLT_BRANCH} && \
    SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct && \
    cd build/release && \
    make package

# Artifact-only stage: contains just the .deb files from the build stage.
FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /

# RCCL: check out the pinned ref and run its install script for the target GPU
# architectures.
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl && \
    cd rccl && \
    git checkout ${RCCL_BRANCH} && \
    ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}

# Artifact-only stage: contains just the .deb files from the build stage.
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /

# Triton: install the Python build prerequisites, check out the pinned ref and
# build a wheel into triton/python/dist.
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && \
    git clone ${TRITON_REPO} && \
    cd triton && \
    git checkout ${TRITON_BRANCH} && \
    cd python && \
    python3 setup.py bdist_wheel --dist-dir=dist

# Artifact-only stage: contains just the wheel(s) from the build stage.
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /

# AMD-SMI: build a wheel from the sources under /opt/rocm/share/amd_smi.
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi && \
    pip wheel . --wheel-dir=dist

# Artifact-only stage: contains just the wheel(s) from the build stage.
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /


# PyTorch build stage. `AS` is upper-cased to match the `FROM` keyword casing
# used by every other stage in this file.
FROM base AS build_pytorch

# If the export_hipblaslt stage holds .deb files, install them. The two sed
# calls then edit /var/lib/dpkg/status in place, deleting the text matched
# between ", hipblaslt-dev " and ", hipcub-dev" and between ", hipblaslt " and
# ", hipfft".
# NOTE(review): presumably this removes versioned hipblaslt dependency entries
# recorded by other packages - confirm.
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
fxmarty's avatar
fxmarty committed
173
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
fxmarty's avatar
fxmarty committed
174

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
# NOTE(review): PYTORCH_VISION_BRANCH is not referenced by any instruction
# visible in this file - confirm whether it is still needed.
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"

# Clone PyTorch at the pinned ref with its submodules, install its Python
# requirements, run the AMD build preparation script and build a wheel into
# pytorch/dist. CMAKE_PREFIX_PATH is set to the active Python's sys.prefix.
RUN git clone ${PYTORCH_REPO} pytorch \
    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
    && pip install -r requirements.txt --no-cache-dir  \
    && python tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist

# Artifact-only stage: contains just the wheel(s) from the build stage.
# `AS` is upper-cased to match the `FROM` keyword casing used elsewhere.
FROM scratch AS export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /

FROM base AS install_deps

ARG COMMON_WORKDIR

# Install hipblaslt
# Runs only when the export_hipblaslt stage (bind-mounted at /install) contains
# .deb files. After `dpkg -i`, the two sed calls edit /var/lib/dpkg/status in
# place, deleting the text matched between ", hipblaslt-dev " and
# ", hipcub-dev" and between ", hipblaslt " and ", hipfft".
# NOTE(review): presumably this removes versioned hipblaslt dependency entries
# recorded by other packages - confirm.
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

# Install the RCCL packages from the export_rccl stage, if any are present
# (`dpkg -i` is deliberately run twice, see the inline note). The sed calls
# apply the same kind of /var/lib/dpkg/status edit as above, for the text
# between ", rccl-dev " / ", rccl " and ", rocalution".
RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        # RCCL needs to be installed twice
        && dpkg -i /install/*.deb \
        && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
        && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
    fi

# Install the Triton wheel from the export_triton stage, if one is present.
RUN --mount=type=bind,from=export_triton,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y triton \
        && pip install /install/*.whl; \
    fi

# Install the amdsmi wheel from the export_amdsmi stage. Unlike the
# neighbouring steps there is no `if ls` guard, so a missing wheel makes this
# step (and the build) fail.
RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
    # Preemptively uninstall to prevent pip same-version no-installs
    pip uninstall -y amdsmi \
    && pip install /install/*.whl;

# Install the wheel(s) from the export_pytorch stage, if any are present,
# replacing any torch/torchvision already installed.
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y torch torchvision \
        && pip install /install/*.whl; \
    fi

FROM install_deps AS kernel-builder
fxmarty's avatar
fxmarty committed
231

fxmarty's avatar
fxmarty committed
232
# Build vllm kernels
fxmarty's avatar
fxmarty committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build Transformers CUDA kernels (gpt-neox and bloom)
251
FROM kernel-builder AS custom-kernels-builder
fxmarty's avatar
fxmarty committed
252
253
WORKDIR /usr/src
COPY server/custom_kernels/ .
fxmarty's avatar
fxmarty committed
254
RUN python setup.py build
fxmarty's avatar
fxmarty committed
255

fxmarty's avatar
fxmarty committed
256
# Build exllama kernels
257
FROM kernel-builder AS exllama-kernels-builder
fxmarty's avatar
fxmarty committed
258
259
260
WORKDIR /usr/src
COPY server/exllama_kernels/ .

fxmarty's avatar
fxmarty committed
261
RUN python setup.py build
fxmarty's avatar
fxmarty committed
262
263

# Build exllama v2 kernels
264
FROM kernel-builder AS exllamav2-kernels-builder
fxmarty's avatar
fxmarty committed
265
266
267
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

fxmarty's avatar
fxmarty committed
268
RUN python setup.py build
fxmarty's avatar
fxmarty committed
269

270
FROM install_deps AS base-copy
fxmarty's avatar
fxmarty committed
271
272

# Text Generation Inference base env
273
ENV HF_HOME=/data \
fxmarty's avatar
fxmarty committed
274
275
276
277
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Copy build artifacts from vllm builder
Nicolas Patry's avatar
Nicolas Patry committed
278
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
fxmarty's avatar
fxmarty committed
279
280

# Copy build artifacts from flash attention v2 builder
Nicolas Patry's avatar
Nicolas Patry committed
281
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
fxmarty's avatar
fxmarty committed
282
283

# Copy build artifacts from custom kernels builder
Nicolas Patry's avatar
Nicolas Patry committed
284
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
fxmarty's avatar
fxmarty committed
285
286

# Copy build artifacts from exllama kernels builder
Nicolas Patry's avatar
Nicolas Patry committed
287
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
fxmarty's avatar
fxmarty committed
288
289

# Copy build artifacts from exllamav2 kernels builder
Nicolas Patry's avatar
Nicolas Patry committed
290
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
fxmarty's avatar
fxmarty committed
291
292
293
294
295
296
297
298

# Install server
# Copies the proto definitions and the server sources, generates the server
# code via `make gen-server`, then installs the ROCm requirements and the
# server package with the accelerate, peft and outlines extras.
# `--no-cache-dir` is passed to both pip calls so pip's download cache does not
# end up in this layer.
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_rocm.txt --no-cache-dir && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir
fxmarty's avatar
fxmarty committed
300
301

# Install benchmarker
302
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
fxmarty's avatar
fxmarty committed
303
# Install router
304
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
fxmarty's avatar
fxmarty committed
305
# Install launcher
306
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
Nicolas Patry's avatar
Nicolas Patry committed
307
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
fxmarty's avatar
fxmarty committed
308
309

# AWS Sagemaker compatible image
310
FROM base AS sagemaker
fxmarty's avatar
fxmarty committed
311

fxmarty's avatar
fxmarty committed
312
313
314
315
316
317
318
319
# Copy the SageMaker entrypoint script into the current working directory,
# make it executable and use it as the container entrypoint.
# NOTE(review): this stage is declared `FROM base`, not `FROM base-copy`, so
# the server, router and launcher installed in `base-copy` are not part of
# this image - confirm that is intended.
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

320
321
322
323
324
325
326
327
328
329
330
331
332
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# Runtime defaults for the ROCm image, declared in a single ENV instruction.
# On MI250 and MI300, flash attention performance with Triton FA is slightly
# better than with CK. However, Triton requires tuning for each prompt length,
# which is prohibitive, so the Triton path is disabled by default.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 \
    ROCM_USE_CUSTOM_PAGED_ATTN=1 \
    PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 \
    VLLM_MOE_PADDING=0 \
    ATTENTION=paged \
    USE_PREFIX_CACHING=0 \
    ROCM_USE_SKINNY_GEMM=1

fxmarty's avatar
fxmarty committed
333
334
335
336
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
fxmarty's avatar
fxmarty committed
337
CMD ["--json-output"]