"vscode:/vscode.git/clone" did not exist on "fa819533801c4438cc533abb95fc3fe21cec73b6"
Dockerfile_amd 10.9 KB
Newer Older
fxmarty's avatar
fxmarty committed
1
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

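# cargo-chef splits the build into a planner stage that records a dependency
# "recipe" and a builder stage that cooks (pre-builds) it, so Docker layer
# caching only rebuilds the Rust dependencies when Cargo.toml/Cargo.lock
# change, not on every source edit.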
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
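
# protoc is required to compile the gRPC .proto definitions under proto/.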
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
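# This produces the text-generation-router, text-generation-launcher and
# text-generation-benchmark binaries that the runtime stages copy in below.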

# Text Generation Inference base image for ROCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git \
    make \
    libmsgpack-dev \
    libssl-dev \
    llvm-dev \
    g++ \
    # Needed to build VLLM & flash.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev \
    hipcub-dev \
    rocblas-dev \
    hiprand-dev \
    hipfft-dev \
    rocrand-dev \
    miopen-hip-dev \
    hipsolver-dev \
    rccl-dev \
    cmake \
    python3.11-venv && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
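# gfx90a covers the MI200 series (MI210/MI250) and gfx942 the MI300 series;
# extend this list to build kernels for other AMD GPU architectures.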

# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
         *)              MAMBA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  exit 1 ;; \
         *)              /opt/conda/bin/conda update -y conda &&  \
                         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*

RUN conda install mkl=2021
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/


ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}


# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
    && cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
    && cd build/release \
    && make package

FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
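
# Each build_* stage exports its artifacts through a `FROM scratch` stage like
# the one above, so downstream stages can pick them up with a cheap
# `--mount=type=bind,from=export_*` instead of inheriting a full build image.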

# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
    && cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /

# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
    && cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /

# AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /


FROM base AS build_pytorch

RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi
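# The sed edits strip hipblaslt from other packages' dependency lists in dpkg's
# status database, so the locally built .debs can replace the stock ROCm
# packages without tripping dependency checks.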

ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
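
# tools/amd_build/build_amd.py hipifies PyTorch's CUDA sources in-tree before
# the wheel is built for the architectures in PYTORCH_ROCM_ARCH.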

RUN git clone ${PYTORCH_REPO} pytorch \
    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
    && pip install -r requirements.txt --no-cache-dir  \
    && python tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /

FROM base AS install_deps

ARG COMMON_WORKDIR

# Install hipblaslt
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        # RCCL needs to be installed twice
        && dpkg -i /install/*.deb \
        && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
        && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_triton,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y triton \
        && pip install /install/*.whl; \
    fi

RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
    # Preemptively uninstall to prevent pip same-version no-installs
    pip uninstall -y amdsmi \
    && pip install /install/*.whl;

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y torch torchvision \
        && pip install /install/*.whl; \
    fi

FROM install_deps AS kernel-builder

# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build Transformers custom kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

RUN python setup.py build

# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

RUN python setup.py build

FROM install_deps AS base-copy

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
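# With HF_HOME=/data, downloaded model weights are cached under /data; mount a
# volume there (e.g. `-v $PWD/data:/data`) to persist them across runs.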

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_rocm.txt && \
    pip install ".[accelerate, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"

# AWS Sagemaker compatible image
FROM base-copy AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# On MI250 and MI300, flash attention performance with the Triton kernels is slightly better than with CK.
# However, Triton requires tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1
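
# The defaults above can be overridden at run time, e.g.
# `docker run -e ROCM_USE_FLASH_ATTN_V2_TRITON=1 ...`.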

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
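
# A minimal build/run sketch (image tag and model id are placeholders; ROCm
# containers need the kfd and dri devices exposed):
#   docker build -f Dockerfile_amd -t tgi-rocm .
#   docker run --device /dev/kfd --device /dev/dri -v $PWD/data:/data \
#       tgi-rocm --model-id <model-id>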