# Dockerfile.rocm_base
# Global build arguments.
# Each *_BRANCH value is the git ref that the matching build stage checks out
# from the corresponding *_REPO. ARGs declared before the first FROM are only
# visible to FROM lines, so every stage below re-declares the ones it uses.
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
ARG TRITON_BRANCH="57c693b6"
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG PYTORCH_BRANCH="89075173"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_BRANCH="v0.24.1"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git"

# TODO: When patch has been upstreamed, switch to the main repo/branch
# ARG RIXL_BRANCH="<TODO>"
# ARG RIXL_REPO="https://github.com/ROCm/RIXL.git"
ARG RIXL_BRANCH="50d63d94"
ARG RIXL_REPO="https://github.com/vcave/RIXL.git"
# Needed by RIXL
ARG ETCD_BRANCH="7c6e714f"
ARG ETCD_REPO="https://github.com/etcd-cpp-apiv3/etcd-cpp-apiv3.git"
ARG UCX_BRANCH="da3fac2a"
ARG UCX_REPO="https://github.com/ROCm/ucx.git"

# Common base stage: ROCm image plus the selected Python version, pip and the
# generic build tooling that every build stage below starts from.
FROM ${BASE_IMAGE} AS base

ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
# GPU targets, semicolon separated. PYTORCH_ROCM_ARCH can be overridden at
# build time; the FlashAttention stage derives its GPU_ARCHS from it and the
# AITER stage passes AITER_ROCM_ARCH as GPU_ARCHS.
# NOTE(review): MORI_GPU_ARCHS is not referenced elsewhere in this file;
# presumably it is read by the MORI build itself - confirm.
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV AITER_ROCM_ARCH=gfx942;gfx950
ENV MORI_GPU_ARCHS=gfx942;gfx950

# Required for RCCL in ROCm7.1
ENV HSA_NO_SCRATCH_RECLAIM=1

ARG PYTHON_VERSION=3.12
ENV PYTHON_VERSION=${PYTHON_VERSION}

RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies.
# The deadsnakes PPA is added with up to three attempts. The outcome is tracked
# in ppa_added because the loop itself exits 0 even when every attempt failed
# (its last command is `sleep`); without the explicit check the build would
# only fail later, at the python${PYTHON_VERSION} install, with a misleading
# error.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
    && ppa_added=0 \
    && for i in 1 2 3; do \
        if add-apt-repository -y ppa:deadsnakes/ppa; then ppa_added=1; break; fi; \
        echo "Attempt $i failed, retrying in 5s..."; sleep 5; \
    done \
    && if [ "$ppa_added" -ne 1 ]; then echo "Failed to add ppa:deadsnakes/ppa after 3 attempts" >&2; exit 1; fi \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
       python${PYTHON_VERSION}-lib2to3 python-is-python3  \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
    && python3 --version && python3 -m pip --version

# Python build tooling shared by the wheel builds in the later stages.
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
# Image/audio libraries; the apt lists are dropped in the same layer.
RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/*

###
### Triton Build
###
# Builds the Triton wheel into /app/install and, when the checkout contains
# python/triton_kernels, the triton_kernels wheel as well.
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
# Clone into an explicit directory so the following steps do not depend on the
# basename of an overridden TRITON_REPO.
RUN git clone ${TRITON_REPO} triton
# setup.py lives either at the repository root or under python/, depending on
# the checked-out revision.
RUN cd triton \
    && git checkout ${TRITON_BRANCH} \
    && if [ ! -f setup.py ]; then cd python; fi \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && mkdir -p /app/install && cp dist/*.whl /app/install
RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \
    && python3 -m build --wheel && cp dist/*.whl /app/install; fi

###
### AMD SMI Build
###
# Builds a wheel from the amd_smi sources under /opt/rocm/share/amd_smi in the
# base image and places it in /app/install.
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install

###
### Pytorch build
###
# Builds PyTorch, torchvision and torchaudio wheels from source. Each wheel is
# pip-installed right after it is built, so torch is already installed in this
# stage when vision and audio are built. All three wheels end up in /app/install.
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_REPO

RUN git clone ${PYTORCH_REPO} pytorch
# tools/amd_build/build_amd.py is run before the wheel build; CMAKE_PREFIX_PATH
# is set to the Python installation prefix for the setup.py invocation only.
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \
    && pip install -r requirements.txt && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_AUDIO_REPO} audio
RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/audio/dist/*.whl /app/install

###
### MORI Build
###
# Builds the MORI wheel into /app/install. The wheels from build_pytorch are
# installed first (bind-mounted for that RUN only).
FROM base AS build_mori
ARG MORI_BRANCH
ARG MORI_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# Clone into an explicit directory so the following steps do not depend on the
# basename of an overridden MORI_REPO.
RUN git clone ${MORI_REPO} mori
# The trailing `ls` makes the step fail if no wheel was produced.
RUN cd mori \
    && git checkout ${MORI_BRANCH} \
    && git submodule update --init --recursive \
    && python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl
RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install

###
### RIXL Build
###
# Builds etcd-cpp-apiv3 and UCX from source, then RIXL against them, and writes
# the RIXL wheel to /app/install. The stage starts from build_pytorch, so it
# inherits that stage's filesystem, including the wheels already present in
# /app/install.
FROM build_pytorch AS build_rixl
ARG RIXL_BRANCH
ARG RIXL_REPO
ARG ETCD_BRANCH
ARG ETCD_REPO
ARG UCX_BRANCH
ARG UCX_REPO

ENV ROCM_PATH=/opt/rocm
ENV UCX_HOME=/usr/local/ucx
ENV RIXL_HOME=/usr/local/rixl
ENV RIXL_BENCH_HOME=/usr/local/rixl_bench

# RIXL build system dependencies and RDMA support
RUN apt-get -y update && apt-get -y install autoconf libtool pkg-config \
    libgrpc-dev \
    libgrpc++-dev \
    libprotobuf-dev \
    protobuf-compiler-grpc \
    libcpprest-dev \
    libaio-dev \
    librdmacm1 \
    librdmacm-dev \
    libibverbs1 \
    libibverbs-dev \
    ibverbs-utils \
    rdmacm-utils \
    ibverbs-providers

RUN pip install meson auditwheel patchelf tomlkit

WORKDIR /workspace

# etcd-cpp-apiv3 (needed by RIXL). Explicit clone directory so the steps do not
# depend on the basename of an overridden ETCD_REPO.
RUN git clone ${ETCD_REPO} etcd-cpp-apiv3 && \
    cd etcd-cpp-apiv3 && \
    git checkout ${ETCD_BRANCH} && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_POLICY_VERSION_MINIMUM=3.5 && \
    make -j$(nproc) && \
    make install

# UCX with ROCm and verbs support, installed under UCX_HOME.
# Parallelism is bounded to the CPU count: a bare `make -j` places no limit on
# the number of concurrent jobs.
RUN cd /usr/local/src && \
    git clone ${UCX_REPO} ucx && \
    cd ucx && \
    git checkout ${UCX_BRANCH} && \
    ./autogen.sh && \
    mkdir build && cd build && \
    ../configure \
        --prefix=${UCX_HOME} \
        --enable-shared \
        --disable-static \
        --disable-doxygen-doc \
        --enable-optimizations \
        --enable-devel-headers \
        --with-rocm=${ROCM_PATH} \
        --with-verbs \
        --with-dm \
        --enable-mt && \
    make -j$(nproc) && \
    make -j$(nproc) install

ENV PATH=${UCX_HOME}/bin:$PATH
ENV LD_LIBRARY_PATH=${UCX_HOME}/lib:${LD_LIBRARY_PATH}

RUN git clone ${RIXL_REPO} /opt/rixl && \
    cd /opt/rixl && \
    git checkout ${RIXL_BRANCH} && \
    meson setup build --prefix=${RIXL_HOME} \
                     -Ducx_path=${UCX_HOME} \
                     -Drocm_path=${ROCM_PATH} && \
    cd build && \
    ninja && \
    ninja install

# Generate RIXL wheel
RUN cd /opt/rixl && mkdir -p /app/install && \
    ./contrib/build-wheel.sh \
        --output-dir /app/install \
        --rocm-dir ${ROCM_PATH} \
        --ucx-plugins-dir ${UCX_HOME}/lib/ucx \
        --nixl-plugins-dir ${RIXL_HOME}/lib/x86_64-linux-gnu/plugins

###
### FlashAttention Build
###
# Builds the FlashAttention wheel into /app/install. The wheels from
# build_pytorch are installed first (bind-mounted for that RUN only).
FROM base AS build_fa
ARG FA_BRANCH
ARG FA_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# Clone into an explicit directory so the following steps do not depend on the
# basename of an overridden FA_REPO.
RUN git clone ${FA_REPO} flash-attention
# GPU_ARCHS is PYTORCH_ROCM_ARCH with every ";gfx1NNN" entry (gfx1 followed by
# three digits) removed; with the default list that leaves gfx90a;gfx942;gfx950.
# Only entries preceded by ';' are stripped, so such a target in first position
# would be kept.
RUN cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install

###
### AITER Build
###
# Builds the AITER wheel into /app/install. The wheels from build_pytorch are
# installed first (bind-mounted for that RUN only).
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# Clone into an explicit directory so the following steps do not depend on the
# basename of an overridden AITER_REPO.
RUN git clone --recursive ${AITER_REPO} aiter
# Submodules are updated again after the checkout so they match AITER_BRANCH.
RUN cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt
# Build with PREBUILD_KERNELS=1 for the targets in AITER_ROCM_ARCH; the trailing
# `ls` makes the step fail if no wheel was produced.
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install

###
### Final Build
###
# Collects the wheels from every build stage into /app/debs. Each stage's
# /app/install directory is bind-mounted at /install for the duration of its
# copy step.
FROM base AS debs
RUN mkdir /app/debs
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_rixl,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs

# Final image: base plus every wheel collected by the debs stage, installed
# from a bind mount of /app/debs.
FROM base AS final
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
    pip install /install/*.whl

# Re-declare the global build arguments so their values are visible in this
# stage and can be recorded below.
ARG BASE_IMAGE
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_AUDIO_REPO
ARG FA_BRANCH
ARG FA_REPO
ARG AITER_BRANCH
ARG AITER_REPO
ARG RIXL_BRANCH
ARG RIXL_REPO
ARG ETCD_BRANCH
ARG ETCD_REPO
ARG UCX_BRANCH
ARG UCX_REPO
ARG MORI_BRANCH
ARG MORI_REPO
# Record the base image and every component repo/ref used for this build in
# /app/versions.txt, one "NAME: value" line each.
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
    && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
    && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" >> /app/versions.txt \
    && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
    && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
    && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
    && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \
    && echo "RIXL_BRANCH: ${RIXL_BRANCH}" >> /app/versions.txt \
    && echo "RIXL_REPO: ${RIXL_REPO}" >> /app/versions.txt \
    && echo "ETCD_BRANCH: ${ETCD_BRANCH}" >> /app/versions.txt \
    && echo "ETCD_REPO: ${ETCD_REPO}" >> /app/versions.txt \
    && echo "UCX_BRANCH: ${UCX_BRANCH}" >> /app/versions.txt \
    && echo "UCX_REPO: ${UCX_REPO}" >> /app/versions.txt \
    && echo "MORI_BRANCH: ${MORI_BRANCH}" >> /app/versions.txt \
    && echo "MORI_REPO: ${MORI_REPO}" >> /app/versions.txt