# Dockerfile.rocm_base
# Global build arguments: the base image plus the pinned ref and source
# repository of each component built in the stages below. Every value can be
# overridden with --build-arg. An ARG declared before the first FROM is only
# visible to FROM lines, so each stage re-declares (without a default) the
# ones it uses.
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
ARG HIPBLASLT_BRANCH="db8e93b4"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="295f2ed4"
ARG PYTORCH_VISION_BRANCH="v0.21.0"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6487649"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

# Common base for every build stage and for the final image.
FROM ${BASE_IMAGE} AS base

# Put the ROCm LLVM toolchain first on PATH and record the ROCm install root.
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
# No trailing ':' here: an empty LD_LIBRARY_PATH entry makes the dynamic
# loader search the current working directory.
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib

# Semicolon-separated list of GPU targets. It is exported as ENV so that the
# RUN commands of the stages derived from this one can read it.
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}

ARG PYTHON_VERSION=3.12

# WORKDIR creates /app when it does not exist, so no separate mkdir is needed.
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies.
#
# - The deadsnakes PPA is added with up to three attempts; if none succeeds
#   the build stops at that point instead of continuing without the PPA.
# - get-pip.py is downloaded to a file with `curl -f` rather than piped into
#   the interpreter, so a failed or HTTP-error download fails the build.
# - The apt package lists are left in place on purpose: the build_hipblaslt
#   stage runs `apt-get install` on top of this layer.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
    && ppa_added=0 \
    && for i in 1 2 3; do \
        if add-apt-repository -y ppa:deadsnakes/ppa; then ppa_added=1; break; fi; \
        echo "Attempt $i failed, retrying in 5s..."; sleep 5; \
    done \
    && [ "$ppa_added" -eq 1 ] \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
       python${PYTHON_VERSION}-lib2to3 python-is-python3  \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -fsSL https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
    && python${PYTHON_VERSION} /tmp/get-pip.py \
    && rm -f /tmp/get-pip.py \
    && python3 --version && python3 -m pip --version

# Python-side build tooling used by the build stages derived from this one.
# cmake is constrained to versions below 4.
RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython

# Builds hipBLAS-common and hipBLASLt as .deb packages and collects them in
# /app/install, from where the final stage installs them.
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
# hipBLAS-common is packaged and installed in this stage before hipBLASLt is built.
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN cd hipBLAS-common \
    && git checkout ${HIPBLAS_COMMON_BRANCH} \
    && mkdir build \
    && cd build \
    && cmake .. \
    && make package \
    && dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
# apt-get update runs in the same layer as the install so that the package
# index is never a stale copy taken from a cached earlier layer.
# LEGACY_HIPBLASLT_OPTION is deliberately unquoted: when empty it must expand
# to no argument at all.
RUN cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && apt-get update -y \
    && apt-get install -y llvm-dev \
    && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
    && cd build/release \
    && make package
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install

# Builds RCCL at the pinned ref and collects the resulting .deb packages in
# /app/install.
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
# The clone target is named explicitly so the following steps keep working
# when RCCL_REPO is overridden with a URL whose last path component is not
# "rccl".
RUN git clone ${RCCL_REPO} rccl
RUN cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install

# Builds a Triton wheel at the pinned ref and collects it in /app/install.
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
# The clone target is named explicitly so the following steps keep working
# when TRITON_REPO is overridden with a URL whose last path component is not
# "triton".
RUN git clone ${TRITON_REPO} triton
RUN cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install

FROM base AS build_amdsmi
# Build wheels from the amd_smi directory under /opt/rocm/share and stage
# them in /app/install.
RUN mkdir -p /app/install \
    && cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist \
    && cp dist/*.whl /app/install

# Builds PyTorch, torchvision and flash-attention wheels at their pinned refs
# and collects all of them in /app/install. PyTorch and torchvision are also
# installed inside this stage after they are built.
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
    pip install -r requirements.txt && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
# Like the two clones above, the target directory is named explicitly so the
# following steps keep working when FA_REPO is overridden with a URL whose
# last path component is not "flash-attention".
RUN git clone ${FA_REPO} flash-attention
# GPU_ARCHS is PYTORCH_ROCM_ARCH with every ";gfx1" + three-digit entry
# stripped by the sed expression, so those targets are not passed to the
# flash-attention build.
RUN cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/flash-attention/dist/*.whl /app/install

# Builds an AITER wheel at the pinned ref and collects it in /app/install.
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
# Install the wheels produced by build_pytorch; they are bind-mounted from
# that stage for this one command only and are not copied into a layer.
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# The clone target is named explicitly so the following steps keep working
# when AITER_REPO is overridden with a URL whose last path component is not
# "aiter".
RUN git clone --recursive ${AITER_REPO} aiter
# Submodules are updated again after the checkout so they match AITER_BRANCH.
RUN cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt
# NOTE(review): GPU_ARCHS is fixed to gfx942 here and does not follow
# PYTORCH_ROCM_ARCH - confirm this is intended.
# The trailing ls fails the step if no wheel was produced.
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install

# Final image: the base stage plus the artifacts of every build stage. Each
# stage's /app/install directory is bind-mounted at /install for a single RUN,
# so the package and wheel files themselves are not copied into a layer.
FROM base AS final
# After installing the .deb files, the sed commands edit /var/lib/dpkg/status
# in place, removing the ", hipblaslt-dev ..." text that precedes ", hipcub-dev"
# and the ", hipblaslt ..." text that precedes ", hipfft".
# NOTE(review): presumably this drops versioned dependencies on the stock
# packages that the locally built ones no longer satisfy - confirm.
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
# Same treatment for the rccl-dev / rccl text that precedes ", rocalution".
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
    pip install /install/*.whl

# Re-declare the global build arguments in this stage so their values are in
# scope, then record them in /app/versions.txt, one "NAME: value" line each.
ARG BASE_IMAGE
ARG HIPBLAS_COMMON_BRANCH
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
ARG AITER_BRANCH
ARG AITER_REPO
# The first echo truncates the file; the rest append. FA_REPO is written along
# with every other declared argument.
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
    && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
    && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
    && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
    && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
    && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
    && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
    && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
    && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
    && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
    && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
    && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
    && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
    && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt