# Base image selection. VERSION is the tag of the base pytorch image and is
# interpolated into the default BASE_IMAGE; override BASE_IMAGE with a build
# argument to replace the whole image reference.
ARG VERSION=2.7.1-ubuntu22.04-dtk26.04-py3.10
ARG BASE_IMAGE=harbor.sourcefind.cn:5443/dcu/admin/base/pytorch:${VERSION}
FROM ${BASE_IMAGE}

# Root URL for build-time downloads. Declared after FROM so that it is visible
# to the RUN steps of this stage; the hyhal archive is fetched from
# ${COS_BASE_URL}/common/hyhal.tar.gz further down.
ARG COS_BASE_URL="https://haihub-model-1251001002.cos.ap-shanghai.myqcloud.com/HyperDrive/haiguang/hygon-bench/train"

# Ubuntu apt source
# Select the apt mirror configuration from VERSION_ID in /etc/os-release:
#   22.04 -> overwrite /etc/apt/sources.list with the TUNA "jammy" suites
#   24.04 -> download a replacement /etc/apt/sources.list.d/ubuntu.sources
# Any other release aborts the build.
RUN UBUNTU_VERSION=$(grep 'VERSION_ID' /etc/os-release | cut -d '=' -f2 | tr -d '"') && \
    echo "Ubuntu version detected: $UBUNTU_VERSION" && \
    case "$UBUNTU_VERSION" in \
        22.04) \
            printf '%s\n' \
                'deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe multiverse' \
                'deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse' \
                'deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse' \
                'deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-security main restricted universe multiverse' \
                > /etc/apt/sources.list ;; \
        24.04) \
            wget -O /etc/apt/sources.list.d/ubuntu.sources https://haihub-model-1251001002.cos.accelerate.myqcloud.com/TCCL/sources_2404.list ;; \
        *) \
            echo "Unsupported Ubuntu version: $UBUNTU_VERSION"; exit 1 ;; \
    esac

# Common build, debugging and networking tools.
# - apt-get is used instead of apt, whose command-line interface is not
#   guaranteed stable for scripted use.
# - DEBIAN_FRONTEND is set inline so no package can block the build on a
#   debconf prompt, without leaking the setting into the runtime environment.
# - The package index is deleted in this same layer; the timezone step below
#   runs its own update, so nothing is lost and the layer stays smaller.
# - Recommended packages are still installed on purpose.
#   NOTE(review): later steps may depend on some of them - confirm before
#   adding --no-install-recommends.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
        aria2 \
        cmake \
        curl \
        iputils-ping \
        libcap2 \
        libnuma-dev \
        lrzsz \
        net-tools \
        pdsh \
        python3-tk \
        pv \
        tmux \
        vim \
        zstd && \
    rm -rf /var/lib/apt/lists/*

# Configure timezone
# /etc/localtime is pointed at Asia/Shanghai first, then tzdata is installed
# and reconfigured so it picks that zone up. DEBIAN_FRONTEND=noninteractive is
# set inline because tzdata's debconf questions would otherwise be able to
# stall an unattended build. The package index is intentionally not removed
# here: later steps in this file still install packages.
RUN ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata && \
    dpkg-reconfigure --frontend noninteractive tzdata

# Support Chinese language
# The package index is refreshed in the same RUN as the install, so this step
# never depends on an index left behind by an earlier (possibly cached) layer.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y language-pack-zh-hans
ENV LANG="zh_CN.UTF-8"
ENV LANGUAGE="zh_CN:zh:en_US:en"

# Allow OpenSSH to talk to containers without asking for confirmation
# Server side: create sshd's runtime directory, generate any missing host
# keys, print the public half of the rsa/ecdsa/ed25519 host keys (the build
# stops if one cannot be read), and switch the stock "#Port 22" line to 3333.
# Client side: create a passphrase-less root key pair, authorize it for root
# logins, and write a client config that forwards the agent, skips host-key
# confirmation and connects on port 3333.
RUN mkdir -p /run/sshd && \
    ssh-keygen -A && \
    ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -y && \
    ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -y && \
    ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -y && \
    sed -i 's/#Port 22/Port 3333/g' /etc/ssh/sshd_config && \
    mkdir -p /root/.ssh/ && \
    ssh-keygen -t rsa -C "" -f ~/.ssh/id_rsa -P "" && \
    cat /root/.ssh/id_rsa.pub >> /root/.ssh/authorized_keys && \
    chmod 0600 /root/.ssh/authorized_keys && \
    printf "Host * \n    ForwardAgent yes \nHost * \n    StrictHostKeyChecking no \nPort 3333" > /root/.ssh/config

# Tsinghua pip source
# Make the TUNA PyPI mirror the global pip index and mark its host as trusted.
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
    pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn

# Create the directory that later COPY steps populate and that the final
# WORKDIR selects. The working directory itself is not changed at this point.
RUN mkdir -p /workspace

# Install OFED UMD
# Downloads the MLNX_OFED 5.8 (ubuntu22.04, x86_64) bundle into the current
# directory, runs its installer in user-space-only mode and removes the
# download and unpacked tree in the same layer.
#   22.04 -> installer is run unmodified
#   24.04 -> the same bundle is used; the installer script is edited first
#            (every "dpatch" is removed and "ubuntu22" becomes "ubuntu2[24]")
#            and --skip-distro-check is passed
#            NOTE(review): the edit presumably lets the installer's distro
#            match accept 24.04 - confirm against mlnxofedinstall.
# Any other release aborts the build.
RUN UBUNTU_VERSION=$(grep 'VERSION_ID' /etc/os-release | cut -d '=' -f2 | tr -d '"') && \
    echo "Ubuntu version detected: $UBUNTU_VERSION" && \
    OFED_PKG="MLNX_OFED_LINUX-5.8-2.0.3.0-ubuntu22.04-x86_64" && \
    OFED_URL="https://haihub-model-1251001002.cos.accelerate.myqcloud.com/TCCL/${OFED_PKG}.tgz"; \
    case "$UBUNTU_VERSION" in \
        22.04) INSTALL_FLAGS=""; PATCH_INSTALLER="false" ;; \
        24.04) INSTALL_FLAGS="--skip-distro-check"; PATCH_INSTALLER="true" ;; \
        *) echo "Unsupported Ubuntu version: $UBUNTU_VERSION"; exit 1 ;; \
    esac && \
    aria2c --enable-color=false --show-console-readout=false -s 8 -x 8 "$OFED_URL" && \
    tar xf ${OFED_PKG}.tgz && \
    cd ${OFED_PKG} && \
    if [ "$PATCH_INSTALLER" = "true" ]; then \
        sed -i -e 's/dpatch//g' -e 's/ubuntu22/ubuntu2[24]/g' mlnxofedinstall; \
    fi && \
    ./mlnxofedinstall --user-space-only --without-fw-update --without-ucx-cuda --force $INSTALL_FLAGS && \
    cd ../ && \
    rm -rf MLNX_OFED_LINUX*

# install rccl-tests
#
# The build needs the hyhal libraries on LD_LIBRARY_PATH, but /opt/hyhal must
# not remain in the final filesystem. hyhal is therefore fetched (only when
# the base image does not already provide /opt/hyhal/lib), used, and deleted
# inside ONE RUN: a delete performed in a later RUN would only hide the files
# while the layer that downloaded them kept its full size.
# Only the *perf binaries are kept (in /usr/local/bin); the source tree is
# removed after the build.
COPY 3rdparty/rccl-tests /workspace/rccl-tests
RUN if [ -e /opt/hyhal/lib ]; then \
        echo "Use existing /opt/hyhal"; \
    else \
        aria2c --enable-color=false --show-console-readout=false -x 8 -s 8 -k 1M -d /tmp -o hyhal.tar.gz "${COS_BASE_URL}/common/hyhal.tar.gz" && \
        mkdir -p /opt/hyhal && \
        tar -zxf /tmp/hyhal.tar.gz --strip-components=2 -C /opt/hyhal && \
        rm -rf /tmp/hyhal.tar.gz; \
    fi && \
    cd /workspace/rccl-tests && \
    LD_LIBRARY_PATH=/opt/hyhal/lib:$LD_LIBRARY_PATH make ROCM_HOME=/opt/dtk NCCL_HOME=/opt/dtk/rccl CUSTOM_RCCL_LIB=/opt/dtk/rccl/lib/librccl.so MPI=1 MPI_HOME=/usr/mpi/gcc/openmpi-4.1.5a1 && \
    cp build/*perf /usr/local/bin/ && \
    cd .. && \
    rm -rf /workspace/rccl-tests /opt/hyhal

# install scripts
COPY scripts/run_rccl_test.sh /workspace/run_rccl_test.sh

# workaround for rccl perf
# COPY creates a missing destination directory (and its parents), so the
# patch and scripts directories need no separate mkdir step.
COPY misc/fix_rccl/fix_graph.xml \
     misc/fix_rccl/fix_topo.xml \
     /opt/dtk/rccl/patch/

COPY scripts/bssh.sh \
     scripts/bscp.sh \
     scripts/cluster_check.sh \
     scripts/get_rdma_order_by_ip.sh \
     scripts/rdma_bw_check.py \
     scripts/rdma_nic_check.py \
     /workspace/scripts/
# Copied on its own with an explicit destination name so the result is the
# same whether rdma_monitor is a single file or a directory.
COPY scripts/rdma_monitor /workspace/scripts/rdma_monitor

# nccl settings
# Network / InfiniBand defaults baked into the runtime environment.
ENV NCCL_SOCKET_IFNAME=eth0 \
    NCCL_IB_GID_INDEX=3 \
    NCCL_IB_DISABLE=0 \
    NCCL_NET_GDR_LEVEL=2 \
    NCCL_IB_QPS_PER_CONNECTION=4 \
    NCCL_IB_TC=160 \
    NCCL_IB_TIMEOUT=22 \
    GLOO_SOCKET_IFNAME=eth0

# Topology override and RCCL tuning. The graph-file override is left disabled.
#ENV NCCL_GRAPH_FILE=/opt/dtk/rccl/patch/fix_graph.xml
ENV NCCL_TOPO_FILE=/opt/dtk/rccl/patch/fix_topo.xml \
    NCCL_ROCE_SRC_PORT_LIST=60000,60051,57663,57804 \
    RCCL_MODEL_MATCHING_DISABLE=1 \
    HSA_FORCE_FINE_GRAIN_PCIE=1 \
    NCCL_PXN_DISABLE=1 \
    NCCL_ALGO=Ring

# workaround for the issue that env get lost over ssh login
# 1. Snapshot the build-time environment into /etc/environment.
# 2. Append to both /root/.bashrc and /root/.hygon_base_env an export that
#    prepends /opt/hyhal/lib to LD_LIBRARY_PATH.
# 3. Append to the same two files the FASTPT torch / CUDA library directories
#    and exports that prepend each of them to LD_LIBRARY_PATH.
RUN env > /etc/environment && \
    printf "\nexport LD_LIBRARY_PATH=\"/opt/hyhal/lib:\$LD_LIBRARY_PATH\"\n\n" \
        | tee -a /root/.bashrc /root/.hygon_base_env && \
    printf '%s\n' \
        "FASTPT_TORCH_LIB_PATH=/usr/local/lib/python3.10/dist-packages/torch/lib" \
        "FASTPT_CUDA_LIB_PATH=/opt/dtk/cuda/cuda-12/lib64/" \
        "" \
        "export LD_LIBRARY_PATH=\"\$FASTPT_TORCH_LIB_PATH:\$LD_LIBRARY_PATH\"" \
        "export LD_LIBRARY_PATH=\"\$FASTPT_CUDA_LIB_PATH:\$LD_LIBRARY_PATH\"" \
        "" \
        | tee -a /root/.bashrc /root/.hygon_base_env

# ulimit
# Append soft and hard open-file limits (4194304) for root to limits.conf.
RUN printf '%s\n' \
        'root soft nofile 4194304' \
        'root hard nofile 4194304' \
        | tee -a /etc/security/limits.conf

WORKDIR "/workspace/"

# Restart sshd, then replace the wrapper shell with the container's main
# process.
# - Exec (JSON) form: Docker does not wrap the command in an extra /bin/sh,
#   and does no variable expansion of its own, so ${@:-bash} reaches bash
#   unchanged.
# - The trailing "--" becomes $0; anything given after the image name on
#   `docker run` becomes "$@". With no arguments the default is an interactive
#   bash, as before; with arguments that command is run instead of being
#   silently ignored.
# - `exec` makes the final process PID 1 so it receives container signals
#   directly.
ENTRYPOINT ["/bin/bash", "-c", "service ssh restart && exec \"${@:-bash}\"", "--"]
