# syntax=docker/dockerfile:1
# The syntax directive pins a BuildKit frontend; the stages below rely on
# `RUN --mount=type=bind`, which requires BuildKit.

# Base image every stage is built FROM (via the `base` stage).
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.1-complete

# Source repositories and revisions (commit, branch, or tag) of each component
# built from source below. They are re-declared in the `final` stage so they can
# be recorded in /app/versions.txt.
ARG TRITON_BRANCH="57c693b6"
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG PYTORCH_BRANCH="1c57644d"
ARG PYTORCH_VISION_BRANCH="v0.23.0"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="59bd8ff2"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV ROCM_PATH=/opt/rocm
# No trailing ':' here: an empty entry in LD_LIBRARY_PATH makes the dynamic
# loader also search the current working directory.
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib

# Semicolon-separated GPU targets, exported to every stage built FROM base
# (build_fa derives its GPU_ARCHS from this value).
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
# Passed as GPU_ARCHS to the aiter wheel build.
ENV AITER_ROCM_ARCH=gfx942;gfx950

# Required for RCCL in ROCm7.1
ENV HSA_NO_SCRATCH_RECLAIM=1

ARG PYTHON_VERSION=3.12

# WORKDIR creates /app if it does not exist.
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies.
# - The deadsnakes PPA is retried up to 3 times and the build fails explicitly
#   if every attempt fails, instead of continuing to a confusing apt error.
# - get-pip.py is downloaded to a file with `curl -f`, so an HTTP or network
#   error fails the build rather than piping nothing into python.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
    && for i in 1 2 3; do \
        add-apt-repository -y ppa:deadsnakes/ppa && break; \
        if [ "$i" -eq 3 ]; then echo "Failed to add ppa:deadsnakes/ppa after $i attempts" >&2; exit 1; fi; \
        echo "Attempt $i failed, retrying in 5s..."; sleep 5; \
    done \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
       python${PYTHON_VERSION}-lib2to3 python-is-python3 \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -fsSL -o /tmp/get-pip.py https://bootstrap.pypa.io/get-pip.py \
    && python${PYTHON_VERSION} /tmp/get-pip.py \
    && rm -f /tmp/get-pip.py \
    && python3 --version && python3 -m pip --version

# Build tooling shared by all build stages. --no-cache-dir keeps pip's download
# cache out of this layer, which the final image inherits.
RUN pip install --no-cache-dir -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython

FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
# Depending on the revision, setup.py lives either at the repo root or under
# python/; build from whichever directory has it and stage the wheel.
RUN cd triton \
    && git checkout ${TRITON_BRANCH} \
    && if [ ! -f setup.py ]; then cd python; fi \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && mkdir -p /app/install && cp dist/*.whl /app/install
# Also build the triton_kernels wheel, but only if this revision contains it.
RUN if [ -d triton/python/triton_kernels ]; then \
        pip install build \
        && cd triton/python/triton_kernels \
        && python3 -m build --wheel \
        && cp dist/*.whl /app/install; \
    fi

FROM base AS build_amdsmi
# Package the amd_smi Python sources shipped under /opt/rocm as a wheel and
# stage it in /app/install for the `debs` stage to collect.
WORKDIR /opt/rocm/share/amd_smi
RUN pip wheel . --wheel-dir=dist \
    && mkdir -p /app/install \
    && cp dist/*.whl /app/install

FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO

# PyTorch: check out the pinned revision, install its Python requirements and
# submodules, run tools/amd_build/build_amd.py (ROCm source preparation), then
# build the wheel and install it into this stage.
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \
    && pip install -r requirements.txt \
    && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl

# torchvision is built after the torch wheel above is installed in this stage.
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl

# Stage both wheels; build_fa, build_aiter and debs bind-mount /app/install.
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install

FROM base AS build_fa
ARG FA_BRANCH
ARG FA_REPO
# Install the torch/torchvision wheels produced by build_pytorch.
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN git clone ${FA_REPO}
# GPU_ARCHS is PYTORCH_ROCM_ARCH with every gfx1xxx target removed. The entry
# and its trailing ';' are stripped together, then any dangling ';' is trimmed.
# This also removes a gfx1xxx target that is first in the list, which the
# previous ';gfx1…' pattern missed.
RUN cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && GPU_ARCHS=$(echo "${PYTORCH_ROCM_ARCH}" | sed -e 's/gfx1[0-9]\{3\};\?//g' -e 's/;$//') \
       python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
# Install the torch/torchvision wheels produced by build_pytorch.
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# Submodules are fetched only after checking out AITER_BRANCH. A recursive clone
# would first fetch them at the default branch's HEAD, and
# `submodule update --recursive` would then redo the work.
RUN git clone ${AITER_REPO}
RUN cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt
# Build with prebuilt kernels for AITER_ROCM_ARCH; the final `ls` fails the
# step if no wheel was produced.
RUN pip install pyyaml \
    && cd aiter \
    && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
    && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
FROM base AS debs
# Despite the stage name, this collects the Python wheels (*.whl) that each
# build_* stage staged in its /app/install into a single /app/debs directory,
# which the final stage installs from.
RUN mkdir /app/debs
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
FROM base AS final
# Install every collected wheel. The bind mount keeps the wheel files out of the
# image, and --no-cache-dir keeps pip's download cache (for dependencies fetched
# while resolving the wheels) out of this layer.
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
    pip install --no-cache-dir /install/*.whl

# ARGs declared before the first FROM are only visible inside a stage after
# being re-declared. They are declared last so changing them only invalidates
# the versions.txt layer.
ARG BASE_IMAGE
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
ARG AITER_BRANCH
ARG AITER_REPO

# Record the source revisions this image was built from, one "NAME: value"
# line per build arg.
RUN { \
        echo "BASE_IMAGE: ${BASE_IMAGE}"; \
        echo "TRITON_BRANCH: ${TRITON_BRANCH}"; \
        echo "TRITON_REPO: ${TRITON_REPO}"; \
        echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}"; \
        echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}"; \
        echo "PYTORCH_REPO: ${PYTORCH_REPO}"; \
        echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}"; \
        echo "FA_BRANCH: ${FA_BRANCH}"; \
        echo "FA_REPO: ${FA_REPO}"; \
        echo "AITER_BRANCH: ${AITER_BRANCH}"; \
        echo "AITER_REPO: ${AITER_REPO}"; \
    } > /app/versions.txt