# Dockerfile.rocm
# Image every stage is built on top of; defaults to the ROCm 6.1 PyTorch image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# rocm/pytorch base images that are tested and supported
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1"
ARG ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
ARG ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# ROCm GPU architectures vLLM is built for unless overridden with --build-arg
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
ARG BUILD_FA="1"
# GPU architectures handed to the flash-attention build as GPU_ARCHS
ARG FA_GFX_ARCHS="gfx90a;gfx942"
# Git ref of ROCm/flash-attention that gets checked out
ARG FA_BRANCH="ae7928c"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
# Git ref of the triton repository that gets checked out
ARG TRITON_BRANCH="0ef1848"
24

25
26
27
28
29
### Base image build stage
FROM ${BASE_IMAGE} AS base

# Re-declare the arg(s) defined before this stage so they are visible inside it
ARG PYTORCH_ROCM_ARCH
30
31

# Install Python tooling and some basic utilities.
# Everything is installed in a single layer so that the apt package index is
# downloaded once and removed in the same layer that created it.
RUN apt-get update && apt-get install -y \
    build-essential \
    bzip2 \
    ca-certificates \
    ccache \
    curl \
    git \
    libx11-6 \
    python3 \
    python3-pip \
    sudo \
    tmux \
    unzip \
    wget \
 && rm -rf /var/lib/apt/lists/*

47
# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
# All following relative paths (libs/, rocm_patch/, setup.py) resolve here
WORKDIR ${APP_MOUNT}

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Bring pip itself up to date before it is used for anything else
RUN pip install --upgrade pip

# Strip every copy of sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; \
    pip uninstall -y sccache; \
    rm -f "$(which sccache)"

# Install torch == 2.4.0 on ROCm.
# The ROCm release is detected from the directory names under /opt; when it is
# one of the known releases, torch is replaced by the nightly from the matching
# wheel index. Any other release leaves the preinstalled torch untouched.
RUN rocm_release="$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')"; \
    case "${rocm_release}" in \
        *"rocm-5.7"*) torch_index="rocm5.7";; \
        *"rocm-6.0"*) torch_index="rocm6.0";; \
        *"rocm-6.1"*) torch_index="rocm6.1";; \
        *) torch_index="";; \
    esac; \
    if [ -n "${torch_index}" ]; then \
        pip uninstall -y torch \
        && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
           --index-url "https://download.pytorch.org/whl/nightly/${torch_index}"; \
    fi
70
71
72
73
74
75

# Tool, library and header search paths for ROCm and libtorch.
# The lists must not contain empty elements (a trailing ':' or a leading ':'
# left behind by an unset variable): an empty element is interpreted as the
# current working directory by the shell, the dynamic linker and the compiler.
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin
# ${VAR:+...} only emits the "<old value>:" prefix when VAR is already set
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/opt/rocm/lib/:/libtorch/lib
ENV CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH:+${CPLUS_INCLUDE_PATH}:}/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Persist the target GPU architectures (from the build arg) and the ccache
# directory in the environment of this stage and everything derived from it
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} \
    CCACHE_DIR=/root/.cache/ccache


### AMD-SMI build stage
FROM base AS build_amdsmi
# The amdsmi wheel is always built, from /opt/rocm/share/amd_smi, into /install
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel --wheel-dir /install .


### Flash-Attention wheel build stage
FROM base AS build_fa
# Import arg(s) defined before the first build stage
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
# The wheel is written to /install, which later stages bind-mount
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_FA" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
    && git clone https://github.com/ROCm/flash-attention.git \
    && cd flash-attention \
    && git checkout "${FA_BRANCH}" \
    && git submodule update --init \
    # On ROCm 5.7 only, apply hipify_patch.patch to torch's hipify script first
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-5.7"*) \
            export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
            && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
        *) ;; esac \
    && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
110
111


112
113
114
115
116
117
118
### Triton wheel build stage
FROM base AS build_triton
# Import arg(s) defined before the first build stage
ARG BUILD_TRITON
ARG TRITON_BRANCH
# Build triton wheel if `BUILD_TRITON = 1`
# The wheel is written to /install, which later stages bind-mount
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_TRITON" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
    && git clone https://github.com/OpenAI/triton.git \
    && cd triton \
    && git checkout "${TRITON_BRANCH}" \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=/install; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi

130
131
132
133

### Final vLLM build stage
FROM base AS final
# Import the vLLM development directory from the build context
# (copied into the WORKDIR inherited from the base stage)
COPY . .
135

136
137
138
139
140
141
142
143
144
145
# numpy 1.20.3 can be left in an odd state: no METADATA etc, but an extra
# LICENSES_bundled.txt. When python3 is the conda py_3.9 interpreter, delete
# that dist-info directory by hand so later numpy upgrades can proceed.
RUN python_bin="$(which python3)"; \
    case "${python_bin}" in \
        */opt/conda/envs/py_3.9*) \
            rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/ ;; \
    esac

# Package upgrades for useful functionality or to avoid dependency issues
# The extras specifier is quoted because `[cli]` is a glob bracket expression
# to the shell and could otherwise be expanded against files in the WORKDIR.
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --upgrade numba scipy "huggingface-hub[cli]"
146

147
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1

# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false
153

154
# Install the ROCm requirements, apply the ROCm-release-specific patches and
# build vLLM in development mode from the copied source tree.
RUN --mount=type=cache,target=${CCACHE_DIR} \
    --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.0"*) \
            patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
        *"rocm-6.1"*) \
            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
            && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
            # Prevent interference if torch bundles its own HIP runtime.
            # Only this removal is best-effort: the braces keep `|| true` from
            # also hiding a failed download or copy of the patched runtime.
            && { rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true; };; \
        *) ;; esac \
    && python3 setup.py clean --all \
    && python3 setup.py develop

# Stage the amdsmi wheel from the build_amdsmi stage under ./libs
RUN --mount=type=bind,from=build_amdsmi,source=/install,target=/install \
    mkdir -p libs \
    && cp /install/*.whl libs \
    # Remove any installed amdsmi up front to avoid same-version no-installs
    && pip uninstall -y amdsmi
176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Stage the triton wheel(s) under ./libs, but only if the build_triton stage
# actually produced any
RUN --mount=type=bind,from=build_triton,source=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Remove any installed triton up front to avoid same-version no-installs
        && pip uninstall -y triton; \
    fi

# Stage the flash-attn wheel(s) under ./libs, but only if the build_fa stage
# actually produced any
RUN --mount=type=bind,from=build_fa,source=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Remove any installed flash-attn up front to avoid same-version no-installs
        && pip uninstall -y flash-attn; \
    fi

# Install every wheel staged under ./libs; nothing happens if there are none
RUN --mount=type=cache,target=/root/.cache/pip \
    if ls libs/*.whl; then \
        pip install libs/*.whl; \
    fi
197
198

# Start bash by default when the container is run without an explicit command
CMD ["/bin/bash"]