Commit 447238de authored by lijian6
Browse files

Adjust the file structure.


Signed-off-by: lijian <lijian6@sugon.com>
parent c1d9c169
...@@ -8,12 +8,12 @@ fi ...@@ -8,12 +8,12 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])") PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])") PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/intranode.hip -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 /opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/intranode.cu -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/runtime.hip -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 /opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/runtime.cu -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 /opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/deep_ep.hip -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 /opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/deep_ep.cu -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/internode.hip -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 /opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/internode.cu -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so.5.2.25211.1469-8d6b0397 /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/lib/ -libverbs -lmlx5 hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so 
/opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
...@@ -21,10 +21,3 @@ python3 --version ...@@ -21,10 +21,3 @@ python3 --version
python setup.py bdist_wheel python setup.py bdist_wheel
echo "✅ Build complete:" echo "✅ Build complete:"
ls -lh dist/ ls -lh dist/
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/intranode.hip -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/runtime.hip -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/deep_ep.hip -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/internode.hip -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L/work/Tmp/DeepEP/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o aaa.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,aaa.so -Wl,-rpath,/work/Tmp/DeepEP/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so.5.2.25211.1469-8d6b0397 /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/lib/ -libverbs -lmlx5
#pragma once #pragma once
#include "./kernels/api.cuh" #include "kernels/api.cuh"
#include "./kernels/configs.cuh" #include "kernels/configs.cuh"
#include "kernels/exception.cuh" #include "kernels/exception.cuh"
namespace deep_ep { namespace deep_ep {
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
// Tuning parameters for the chunked dispatch/combine paths, plus helpers that estimate how many
// bytes of NVL and RDMA buffer a given configuration requires.
struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

    // Stores the five parameters and validates them:
    //  - `num_sms` must be non-negative;
    //  - every chunk size must be positive;
    //  - each send chunk must be strictly smaller than the matching receive chunk;
    //  - the RDMA send chunk must be at most half of the RDMA receive chunk.
    // The stored RDMA receive chunk size is passed through `ALIGN<int>` together with the RDMA
    // send chunk size (presumably rounding it up to a multiple of the send size -- confirm
    // against the definition of `ALIGN`).
    Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
        : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
        EP_HOST_ASSERT(num_sms >= 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
                       num_max_nvl_chunked_recv_tokens > 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
                       num_max_rdma_chunked_recv_tokens > 0);

        // Ceil up RDMA buffer size
        this->num_max_rdma_chunked_recv_tokens =
            ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);

        // The unqualified names below are the constructor arguments (they shadow the members),
        // so these two checks are evaluated on the receive size as passed in by the caller.
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
        // have space to push
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
                       num_max_rdma_chunked_recv_tokens / 2);
    }

    // Returns the number of bytes needed for the NVL buffer, rounded up to a multiple of 128.
    //   hidden_bytes: bytes occupied by one token's hidden vector.
    //   num_ranks:    total number of ranks; must be below `NUM_MAX_NVL_PEERS` or a multiple
    //                 of it.
    // All products are carried out in `size_t` so that large configurations cannot overflow
    // `int` before being widened.
    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
        const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
        const int num_channels = num_sms / 2;

        // One queue per (channel, NVL rank); each queue holds `recv_tokens` slots.
        const size_t num_queues =
            static_cast<size_t>(num_channels) * static_cast<size_t>(num_nvl_ranks);
        const size_t num_slots =
            num_queues * static_cast<size_t>(num_max_nvl_chunked_recv_tokens);

        size_t num_bytes = 0;
        num_bytes += num_queues * static_cast<size_t>(2 * num_rdma_ranks + 3) * sizeof(int);
        num_bytes += num_slots * hidden_bytes;
#ifndef DISABLE_ROCSHMEM
        num_bytes += num_slots * static_cast<size_t>(internode::get_source_meta_bytes());
#endif
        num_bytes += num_slots * kNumMaxTopK * sizeof(int64_t);
        num_bytes += num_slots * kNumMaxTopK * sizeof(float);
        num_bytes += num_slots * kNumMaxScales * sizeof(float);
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    // Returns the number of bytes needed for the RDMA buffer, rounded up to a multiple of 128,
    // or 0 when `num_ranks <= NUM_MAX_NVL_PEERS`.
    //   hidden_bytes: bytes occupied by one token's hidden vector.
    //   num_ranks:    total number of ranks; must be a multiple of `NUM_MAX_NVL_PEERS` when
    //                 larger than it, and `num_sms` must then be even.
    // When built with DISABLE_ROCSHMEM the function only raises the host assertion.
    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_ROCSHMEM
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_sms % 2 == 0);
        const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
        const int num_channels = num_sms / 2;

        // Widen before multiplying so the products cannot overflow `int`.
        const size_t num_queues =
            static_cast<size_t>(num_channels) * static_cast<size_t>(num_rdma_ranks);
        const size_t num_slots =
            num_queues * static_cast<size_t>(num_max_rdma_chunked_recv_tokens);

        // Every term carries a trailing factor of 2, as in the original sizing formula.
        size_t num_bytes = 0;
        num_bytes += num_queues * static_cast<size_t>(NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
        num_bytes += num_slots * static_cast<size_t>(hidden_bytes) * 2;
        num_bytes += num_slots * static_cast<size_t>(internode::get_source_meta_bytes()) * 2;
        num_bytes += num_slots * kNumMaxTopK * sizeof(int64_t) * 2;
        num_bytes += num_slots * kNumMaxTopK * sizeof(float) * 2;
        num_bytes += num_slots * kNumMaxScales * sizeof(float) * 2;
        num_bytes += num_slots * sizeof(int4) * 2;
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
#else
        EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
                                 "rocSHMEM by following docs/install_dependencies.md");
        // Not reached when the assertion fires; keeps every control path returning a value.
        return 0;
#endif
    }
};
// Plain aggregate holding the pointers of one low-latency buffer set. The member order matters:
// `LowLatencyLayout` fills instances with a positional braced initializer.
struct LowLatencyBuffer {
    // Number of `int` entries reported by `clean_meta()`.
    int num_clean_int = 0;

    // Dispatch-side pointers.
    void *dispatch_rdma_send_buffer = nullptr;
    void *dispatch_rdma_recv_data_buffer = nullptr;
    int *dispatch_rdma_recv_count_buffer = nullptr;

    // Combine-side pointers.
    void *combine_rdma_send_buffer = nullptr;
    void *combine_rdma_recv_data_buffer = nullptr;
    int *combine_rdma_recv_flag_buffer = nullptr;

    void *combine_rdma_send_buffer_data_start = nullptr;
    size_t num_bytes_per_combine_msg = 0;

    // Returns {pointer, int count} describing the signaling area. Asserts that the dispatch
    // count buffer and the combine flag buffer are the same memory.
    std::pair<int *, int> clean_meta() {
        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
        return std::make_pair(dispatch_rdma_recv_count_buffer, num_clean_int);
    }
};
// Computes how a single RDMA buffer is carved into two `LowLatencyBuffer` sets and how many
// bytes that carving needs in total.
//
// Resulting byte layout, relative to `rdma_buffer`:
//   [ signaling 0 | signaling 1 | send 0 | send 1 | recv 0 | recv 1 ]
// where each signaling area is padded via `ALIGN<size_t>(..., 128)`.
struct LowLatencyLayout {
    size_t total_bytes = 0;
    LowLatencyBuffer buffers[2];

    // Reinterprets `ptr` as `count_ptr_t`, steps it forward by `count` elements (bytes with the
    // default `uint8_t *`), and reinterprets the result as `out_ptr_t`.
    template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
              typename in_ptr_t = void *>
    out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
        auto stepped = reinterpret_cast<count_ptr_t>(ptr) + count;
        return reinterpret_cast<out_ptr_t>(stepped);
    }

    // Builds the layout on top of `rdma_buffer` (which may be null when only `total_bytes` is
    // of interest). `num_ranks` is accepted but not used by the computation.
    LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
                     int num_ranks, int num_experts) {
        const int num_scales = hidden / 128;

        // Message sizes
        // NOTES: you should add a control `int4` for combine messages if you want to do data
        // transformation
        EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
        const size_t bf16_payload_bytes = hidden * sizeof(hip_bfloat16);
        const size_t scaled_payload_bytes = hidden + num_scales * sizeof(float);
        size_t num_bytes_per_dispatch_msg =
            sizeof(int4) + std::max(bf16_payload_bytes, scaled_payload_bytes);
        size_t num_bytes_per_combine_msg = bf16_payload_bytes;

        // Number of (expert, token) slots used by the per-expert regions.
        const int num_expert_slots = num_experts * num_max_dispatch_tokens_per_rank;

        // Send buffer: large enough for either a dispatch or a combine round.
        size_t dispatch_send_buffer_bytes =
            num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_send_buffer_bytes = num_expert_slots * num_bytes_per_combine_msg;
        size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
        total_bytes += send_buffer_bytes * 2;

        // Symmetric receive buffers
        // TODO: optimize memory usages
        size_t dispatch_recv_data_buffer_bytes = num_expert_slots * num_bytes_per_dispatch_msg;
        size_t combine_recv_buffer_bytes = num_expert_slots * num_bytes_per_combine_msg;
        size_t recv_buffer_bytes =
            std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
        EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
        total_bytes += recv_buffer_bytes * 2;

        // Symmetric signaling buffers: one `int` per expert for both dispatch and combine.
        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
        size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
        size_t signaling_buffer_bytes =
            std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
        size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
        total_bytes += signaling_buffer_bytes_aligned * 2;

        // Byte offsets at which the send and receive regions begin.
        const size_t send_region_offset = signaling_buffer_bytes_aligned * 2;
        const size_t recv_region_offset = send_region_offset + send_buffer_bytes * 2;

        // Assign pointers
        // NOTES: dispatch and combine share the same memory, so several fields of each buffer
        // set receive the same pointer.
        for (int phase = 0; phase < 2; ++phase) {
            void *send_ptr = advance(rdma_buffer, send_region_offset + send_buffer_bytes * phase);
            void *recv_ptr = advance(rdma_buffer, recv_region_offset + recv_buffer_bytes * phase);
            int *signal_ptr =
                advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * phase);
            buffers[phase] = {static_cast<int>(signaling_buffer_bytes / sizeof(int)),
                              send_ptr,
                              recv_ptr,
                              signal_ptr,
                              send_ptr,
                              recv_ptr,
                              signal_ptr,
                              send_ptr,
                              num_bytes_per_combine_msg};
        }
    }
};
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h>
#include <chrono>
#include <hip/hip_runtime.h>
#include <pybind11/functional.h>
#include <torch/python.h>
#include "../kernels/deep_ep/api.h"
#include "../kernels/deep_ep/configs.h"
#include "deep_ep.hpp"
namespace deep_ep {
// Constructs the communication buffer for one rank.
//
//   rank / num_ranks:   this process' index and the total number of ranks.
//   num_nvl_bytes:      size of the local NVL buffer; 0 skips its allocation.
//   num_rdma_bytes:     size of the RDMA buffer (only validated here, not allocated).
//   low_latency_mode:   relaxes the rank-count and RDMA-size limits checked below.
//   explicitly_destroy: stored; consulted by the destructor.
//   use_default_stream_as_comm_stream: pick the current stream instead of one from the pool.
//
// Every HIP runtime call is wrapped in `CUDA_CHECK`, matching `destroy()`. `EP_HOST_ASSERT`
// is a boolean assertion and must not wrap them: `hipSuccess` is 0, so a successful call
// would be reported as a failed assertion.
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
               bool low_latency_mode, bool explicitly_destroy,
               bool use_default_stream_as_comm_stream)
    : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
      num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
      explicitly_destroy(explicitly_destroy),
      use_default_stream_as_comm_stream(use_default_stream_as_comm_stream),
      comm_stream(use_default_stream_as_comm_stream
                      ? at::hip::getCurrentHIPStreamMasqueradingAsCUDA()
                      : at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
    // Metadata memory appended after the NVL payload: barrier signals, then the table of
    // buffer pointers, then the table of barrier-signal pointers.
    int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
    int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void *);
    int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);

    // Common checks
    EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
                   (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
    EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
                   (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
    EP_HOST_ASSERT(0 <= rank and rank < num_ranks and
                   (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
    EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
    if (num_rdma_bytes > 0)
        EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode);

    // Get ranks
    CUDA_CHECK(hipGetDevice(&device_id));
    rdma_rank = rank / NUM_MAX_NVL_PEERS;
    nvl_rank = rank % NUM_MAX_NVL_PEERS;
    num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
    num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
#ifdef DISABLE_ROCSHMEM
    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and
                   "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                   "following docs/install_dependencies.md");
#endif

    // Get device info
    hipDeviceProp_t device_prop = {};
    CUDA_CHECK(hipGetDeviceProperties(&device_prop, device_id));
    num_device_sms = device_prop.multiProcessorCount;

    if (num_nvl_bytes > 0) {
        // Local IPC: alloc local memory and set local IPC handles
        CUDA_CHECK(hipExtMallocWithFlags(
            &buffer_ptrs[nvl_rank],
            num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes,
            hipDeviceMallocUncached));
        CUDA_CHECK(hipIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
        buffer_ptrs_gpu = reinterpret_cast<void **>(static_cast<uint8_t *>(buffer_ptrs[nvl_rank]) +
                                                    num_nvl_bytes + barrier_signal_bytes);

        // Set barrier signals
        barrier_signal_ptrs[nvl_rank] =
            reinterpret_cast<int *>(static_cast<uint8_t *>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
        barrier_signal_ptrs_gpu =
            reinterpret_cast<int **>(static_cast<uint8_t *>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
                                     barrier_signal_bytes + buffer_ptr_bytes);

        // No need to synchronize, will do a full device sync during `sync`
        CUDA_CHECK(
            hipMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
    }

    // Create 32 MiB workspace
    CUDA_CHECK(hipMalloc(&workspace, NUM_WORKSPACE_BYTES));
    CUDA_CHECK(hipMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));

    // MoE counter: mapped host memory, initialised to -1, with its device-side alias.
    CUDA_CHECK(hipHostMalloc(&moe_recv_counter, sizeof(int64_t), hipHostMallocMapped));
    CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_counter_mapped),
                                       const_cast<int *>(moe_recv_counter), 0));
    *moe_recv_counter = -1;

    // MoE expert-level counter
    CUDA_CHECK(hipHostMalloc(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS,
                             hipHostMallocMapped));
    CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_expert_counter_mapped),
                                       const_cast<int *>(moe_recv_expert_counter), 0));
    for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i)
        moe_recv_expert_counter[i] = -1;

    // MoE RDMA-level counter
    if (num_rdma_ranks > 0) {
        CUDA_CHECK(hipHostMalloc(&moe_recv_rdma_counter, sizeof(int), hipHostMallocMapped));
        CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_rdma_counter_mapped),
                                           const_cast<int *>(moe_recv_rdma_counter), 0));
        *moe_recv_rdma_counter = -1;
    }
}
// Releases the buffer unless the owner opted into explicit destruction. In explicit mode the
// destructor frees nothing; it only prints a warning when `destroy()` was never called.
// Declared `noexcept(false)` so that `destroy()` may propagate exceptions.
Buffer::~Buffer() noexcept(false) {
    if (explicitly_destroy) {
        if (not destroyed) {
            printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak "
                   "resources.\n");
            fflush(stdout);
        }
        return;
    }
    destroy();
}
// Returns the `available` flag; `destroy()` resets it to false.
bool Buffer::is_available() const {
return available;
}
// True only when the buffer is available and there are more ranks than `NUM_MAX_NVL_PEERS`.
bool Buffer::is_internode_available() const {
    if (not is_available())
        return false;
    return num_ranks > NUM_MAX_NVL_PEERS;
}
// Returns `num_rdma_ranks`, which the constructor sets to
// max(1, num_ranks / NUM_MAX_NVL_PEERS).
int Buffer::get_num_rdma_ranks() const {
return num_rdma_ranks;
}
// Returns `rdma_rank`, which the constructor sets to rank / NUM_MAX_NVL_PEERS.
int Buffer::get_rdma_rank() const {
return rdma_rank;
}
// Returns `nvl_rank` when `global` is true, and 0 otherwise.
int Buffer::get_root_rdma_rank(bool global) const {
    if (global)
        return nvl_rank;
    return 0;
}
// Returns `device_id`, the value obtained from hipGetDevice in the constructor.
int Buffer::get_local_device_id() const {
return device_id;
}
// Returns the `HIP_IPC_HANDLE_SIZE` bytes of this rank's IPC memory handle as a Python
// bytearray.
pybind11::bytearray Buffer::get_local_ipc_handle() const {
    const auto &local_handle = ipc_handles[nvl_rank];
    return pybind11::bytearray(local_handle.reserved, HIP_IPC_HANDLE_SIZE);
}
// Returns the bytes of `internode::get_unique_id()` as a Python bytearray. Asserts that this
// rank has `rdma_rank == 0`. When built with DISABLE_ROCSHMEM the function only raises the
// host assertion.
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
#ifndef DISABLE_ROCSHMEM
    EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get ROCSHMEM unique ID");
    const auto root_id = internode::get_unique_id();
    const auto *raw_bytes = reinterpret_cast<const char *>(root_id.data());
    return pybind11::bytearray(raw_bytes, root_id.size());
#else
    EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                             "following docs/install_dependencies.md");
#endif
}
// Wraps the local NVL buffer (or the RDMA buffer when `use_rdma_buffer` is set) in a 1-D
// tensor of the requested dtype without copying.
//   dtype:  Python dtype object, converted with `py_object_to_dtype`.
//   offset: byte offset added to the start of the chosen buffer.
// The element count is the full buffer size divided by the element size.
// NOTE(review): `offset` is not subtracted from the byte count, so a non-zero offset yields a
// view that extends past the end of the region -- confirm callers slice the result.
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
                                              bool use_rdma_buffer) const {
    const torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
    const auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));

    auto region = use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank];
    auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;

    auto base_ptr = static_cast<uint8_t *>(region) + offset;
    const auto options = torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA);
    return torch::from_blob(base_ptr, num_bytes / element_bytes, options);
}
// Returns the communication stream selected in the constructor, converted to `torch::Stream`.
torch::Stream Buffer::get_comm_stream() const {
return comm_stream;
}
// Releases the resources held by this buffer and marks it destroyed.
//
// Steps, in order:
//   1. Assert the buffer has not been destroyed already, then synchronize the
//      device.
//   2. If an NVL buffer exists (`num_nvl_bytes > 0`): run an intranode barrier
//      on `comm_stream`, synchronize, close every peer's IPC mapping (only
//      when the buffer is available), then free the local buffer.
//   3. Unless built with DISABLE_ROCSHMEM, and only when the buffer is
//      available with `num_rdma_bytes > 0`: synchronize, barrier, free the
//      RDMA buffer and finalize the internode runtime.
//   4. Free the host-side MoE receive counter.
//   5. Set `destroyed = true` and `available = false`.
//
// Every HIP runtime call goes through CUDA_CHECK. The host counter free used
// to be wrapped in EP_HOST_ASSERT, which tests truthiness: since hipSuccess
// is 0, that assertion failed on success and passed on error.
void Buffer::destroy() {
    EP_HOST_ASSERT(not destroyed);

    // Synchronize
    CUDA_CHECK(hipDeviceSynchronize());

    if (num_nvl_bytes > 0) {
        // Barrier
        intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
        CUDA_CHECK(hipDeviceSynchronize());

        // Close remote IPC (the local rank's entry is freed below, not closed)
        if (is_available()) {
            for (int i = 0; i < num_nvl_ranks; ++i) {
                if (i != nvl_rank)
                    CUDA_CHECK(hipIpcCloseMemHandle(buffer_ptrs[i]));
            }
        }

        // Free local buffer and error flag
        CUDA_CHECK(hipFree(buffer_ptrs[nvl_rank]));
    }

#ifndef DISABLE_ROCSHMEM
    // Free the RDMA buffer and shut down the internode runtime
    if (is_available() and num_rdma_bytes > 0) {
        CUDA_CHECK(hipDeviceSynchronize());
        internode::barrier();
        internode::free(rdma_buffer_ptr);
        internode::finalize();
    }
#endif

    // Free workspace and MoE counter
    // NOTE(review): only the MoE counter is released in the visible code;
    // confirm whether a workspace allocation also needs freeing here.
    CUDA_CHECK(hipFreeHost(const_cast<int *>(moe_recv_counter)));

    // Free chunked mode stuffs
    // NOTE(review): no release call follows this comment in the visible code;
    // confirm whether `moe_recv_expert_counter` should be freed here.

    destroyed = true;
    available = false;
}
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
EP_HOST_ASSERT(not is_available());
EP_HOST_ASSERT(not is_available());
// Sync IPC handles
if (num_nvl_bytes > 0) {
EP_HOST_ASSERT(num_ranks == device_ids.size());
EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
if (offset + i != rank) {
if (offset + i != rank) {
EP_HOST_ASSERT(hipIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i],
hipIpcMemLazyEnablePeerAccess));
barrier_signal_ptrs[i] =
reinterpret_cast<int *>(static_cast<uint8_t *>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
} else {
HIP_IPC_HANDLE_SIZE) == 0);
}
}
}
// Copy all buffer and barrier signal pointers to GPU
sizeof(void *) * NUM_MAX_NVL_PEERS,
hipMemcpyHostToDevice));
EP_HOST_ASSERT(hipMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs,
sizeof(int *) * NUM_MAX_NVL_PEERS, hipMemcpyHostToDevice));
EP_HOST_ASSERT(hipDeviceSynchronize());
}
}
#ifndef DISABLE_ROCSHMEM
if (num_rdma_bytes > 0) {
if (num_rdma_bytes > 0) {
// Initialize NVSHMEM
EP_HOST_ASSERT(root_unique_id_opt.has_value());
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
internode::init(
root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
internode::barrier();
internode::barrier();
// Allocate
internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode)
// Barrier
internode::barrier();
}
}
#endif
// Ready to use
available = true;
}
std::optional<EventHandle>>
Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream) {
EP_HOST_ASSERT(topk_idx.dim() == 2);
EP_HOST_ASSERT(topk_idx.dim() == 2);
EP_HOST_ASSERT(topk_idx.is_contiguous());
EP_HOST_ASSERT(num_experts > 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
if (allocate_on_comm_stream) {
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
}
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
}
}
num_topk = static_cast<int>(topk_idx.size(1));
auto num_tokens_per_rank =
torch::empty({num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
{num_experts}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
auto is_token_in_rank = torch::empty(
{num_tokens, num_ranks}, torch::TensorOptions().dtype(torch::kBool).device(torch::kCUDA));
if (is_internode_available())
if (is_internode_available())
{num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
topk_idx.data_ptr<int64_t>(), num_tokens_per_rank.data_ptr<int>(),
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>()
: nullptr,
num_tokens_per_expert.data_ptr<int>(), is_token_in_rank.data_ptr<bool>(), num_tokens,
num_topk, num_ranks, num_experts, comm_stream);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
t.record_stream(comm_stream);
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
to.has_value() ? to->record_stream(comm_stream) : void();
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
}
}
// Switch back compute stream
if (allocate_on_comm_stream)
event};
}
}
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::intranode_dispatch(
const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank, const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert, int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix, int expert_alignment,
int num_worst_tokens, const Config &config, std::optional<EventHandle> &previous_event,
bool async, bool allocate_on_comm_stream) {
bool cached_mode = cached_rank_prefix_matrix.has_value();
bool cached_mode = cached_rank_prefix_matrix.has_value();
// receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
is_token_in_rank.size(1) == num_ranks);
if (cached_mode) {
if (cached_mode) {
cached_rank_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and
cached_rank_prefix_matrix->size(1) == num_ranks);
EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and
cached_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and
cached_channel_prefix_matrix->size(1) == num_channels);
} else {
} else {
num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
num_local_experts = num_experts / num_ranks;
// Top-k checks
int64_t *topk_idx_ptr = nullptr;
float *topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_weights_ptr = topk_weights->data_ptr<float>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
x_scales_ptr = static_cast<float *>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
if (allocate_on_comm_stream) {
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
}
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
}
}
// Create handles (only return for non-cached mode)
auto rank_prefix_matrix = torch::Tensor();
auto channel_prefix_matrix = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
std::vector<int> num_recv_tokens_per_expert_list;
{num_local_experts}, torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA));
// Barrier or send sizes
// To clean: channel start/end offset, head and tail
int num_memset_int = num_channels * num_ranks * 4;
if (cached_mode) {
rank_prefix_matrix = cached_rank_prefix_matrix.value();
channel_prefix_matrix = cached_channel_prefix_matrix.value();
channel_prefix_matrix = cached_channel_prefix_matrix.value();
// Copy rank prefix matrix and clean flags
rank_prefix_matrix.data_ptr<int>(), num_memset_int, buffer_ptrs_gpu,
barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream);
} else {
} else {
torch::empty({num_ranks, num_ranks},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
channel_prefix_matrix =
torch::empty({num_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
// Meta information:
// - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`
// - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`
// NOTES: no more token dropping in this version
*moe_recv_counter = -1;
moe_recv_expert_counter[i] = -1;
moe_recv_expert_counter[i] = -1;
num_nvl_bytes);
intranode::notify_dispatch(
num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped,
num_recv_tokens_per_expert.data_ptr<int64_t>(), num_experts, num_tokens,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(), num_memset_int, expert_alignment, buffer_ptrs_gpu,
barrier_signal_ptrs_gpu, rank, comm_stream, num_channels);
if (num_worst_tokens > 0) {
// No CPU sync, just allocate the worst case
num_recv_tokens = num_worst_tokens;
// Must be forward with top-k stuffs
EP_HOST_ASSERT(topk_idx.has_value());
EP_HOST_ASSERT(topk_weights.has_value());
} else {
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
std::chrono::high_resolution_clock::now() - start_time)
.count() > NUM_CPU_TIMEOUT_SECS)
throw std::runtime_error("DeepEP error: CPU recv timeout");
throw std::runtime_error("DeepEP error: CPU recv timeout");
}
moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
}
}
// Allocate new tensors
auto recv_src_idx = torch::empty(
{num_recv_tokens}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
auto recv_topk_idx = std::optional<torch::Tensor>(),
recv_topk_weights = std::optional<torch::Tensor>(),
recv_x_scales = std::optional<torch::Tensor>();
auto recv_channel_prefix_matrix =
torch::empty({num_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
auto send_head = torch::empty({num_tokens, num_ranks},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
// Assign pointers
float *recv_topk_weights_ptr = nullptr;
float *recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
if (topk_idx.has_value()) {
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
? torch::empty({num_recv_tokens}, x_scales->options())
: torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = static_cast<float *>(recv_x_scales->data_ptr());
}
}
// Dispatch
num_channels * num_ranks * sizeof(int) + // Channel start offset
num_channels * num_ranks * sizeof(int) + // Channel end offset
num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
hidden * recv_x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
num_topk * sizeof(int64_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
num_topk * sizeof(float) + // Top-k weight buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
sizeof(float) * num_scales // FP8 scale buffer
<= num_nvl_bytes);
intranode::dispatch(
recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr,
recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(), num_tokens,
num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk,
num_experts, num_scales, scale_token_stride, scale_hidden_stride, buffer_ptrs_gpu, rank,
num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens,
config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
recv_src_idx, recv_channel_prefix_matrix, send_head}) {
t.record_stream(comm_stream);
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
{x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx,
recv_topk_weights, recv_x_scales}) {
to.has_value() ? to->record_stream(comm_stream) : void();
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
}
}
// Switch back compute stream
if (allocate_on_comm_stream)
// Return values
recv_x_scales,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
num_recv_tokens_per_expert,
rank_prefix_matrix,
channel_prefix_matrix,
recv_channel_prefix_matrix,
recv_src_idx,
send_head,
event};
}
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
const std::optional<torch::Tensor> &bias_0,
const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix,
const torch::Tensor &send_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream) {
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
src_idx.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and
send_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and
rank_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and
channel_prefix_matrix.is_contiguous() and
channel_prefix_matrix.scalar_type() == torch::kInt32);
// receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_recv_tokens = static_cast<int>(send_head.size(0));
EP_HOST_ASSERT(src_idx.size(0) == num_tokens);
EP_HOST_ASSERT(send_head.size(1) == num_ranks);
rank_prefix_matrix.size(1) == num_ranks);
EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and
channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
if (allocate_on_comm_stream) {
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
}
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
}
}
auto recv_topk_weights = std::optional<torch::Tensor>();
float *topk_weights_ptr = nullptr;
float *recv_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_weights_ptr = topk_weights->data_ptr<float>();
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
// Launch barrier and reset queue head and tail
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
buffer_ptrs_gpu, send_head.data_ptr<int>(), num_channels, num_recv_tokens,
num_channels * num_ranks * 2, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream);
// Assign bias pointers
void *bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++i)
if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
hidden * x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens *
num_topk * sizeof(float) // Top-k weight buffer
<= num_nvl_bytes);
intranode::combine(
at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(),
recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(),
channel_prefix_matrix.data_ptr<int>(), send_head.data_ptr<int>(), num_tokens,
num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream,
config.num_sms, config.num_max_nvl_chunked_send_tokens,
config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
t.record_stream(comm_stream);
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
to.has_value() ? to->record_stream(comm_stream) : void();
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
}
}
// Switch back compute stream
if (allocate_on_comm_stream)
return {recv_x, recv_topk_weights, event};
}
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor,
std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<EventHandle>>
Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM
// In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks,
// which can be quite long. If users of DeepEP need to execute other Python code on other
// threads, such as KV transfer, their code will get stuck due to GIL unless we release GIL
// here.
pybind11::gil_scoped_release release;
pybind11::gil_scoped_release release;
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
if (cached_mode) {
cached_rdma_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and
cached_rdma_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and
cached_recv_rdma_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
cached_gbl_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and
cached_gbl_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and
cached_recv_gbl_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
} else {
num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and
num_tokens_per_rdma_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and
num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
}
hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)),
num_local_experts = num_experts / num_ranks;
// Top-k checks
int64_t *topk_idx_ptr = nullptr;
float *topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_weights_ptr = topk_weights->data_ptr<float>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
x_scales->scalar_type() == torch::kInt);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->dim() == 2);
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
x_scales_ptr = static_cast<float *>(x_scales->data_ptr());
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only return for non-cached mode)
auto rdma_channel_prefix_matrix = torch::Tensor();
auto recv_rdma_rank_prefix_sum = torch::Tensor();
auto gbl_channel_prefix_matrix = torch::Tensor();
auto recv_gbl_rank_prefix_sum = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
std::vector<int> num_recv_tokens_per_expert_list;
// Barrier or send sizes
if (cached_mode) {
num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
// Just a barrier and clean flags
hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr,
nullptr, nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank,
comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode);
} else {
} else {
torch::empty({num_rdma_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
recv_rdma_rank_prefix_sum = torch::empty(
{num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
gbl_channel_prefix_matrix =
torch::empty({num_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_rank_prefix_sum = torch::empty(
{num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
*moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
moe_recv_expert_counter[i] = -1;
moe_recv_expert_counter[i] = -1;
num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels, hidden_int4, num_scales,
num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr<int>(),
recv_rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
recv_gbl_rank_prefix_sum.data_ptr<int>(), rdma_buffer_ptr,
config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes,
low_latency_mode);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
ready &= moe_recv_expert_counter[i] >= 0;
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
std::chrono::high_resolution_clock::now() - start_time)
.count() > NUM_CPU_TIMEOUT_SECS) {
printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank,
num_recv_tokens, num_rdma_recv_tokens);
for (int i = 0; i < num_local_experts; ++i)
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
}
}
std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
}
// Allocate new tensors
auto recv_topk_idx = std::optional<torch::Tensor>(),
recv_topk_weights = std::optional<torch::Tensor>(),
recv_x_scales = std::optional<torch::Tensor>();
auto recv_src_meta = std::optional<torch::Tensor>();
auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
auto send_rdma_head = std::optional<torch::Tensor>();
auto send_nvl_head = std::optional<torch::Tensor>();
if (not cached_mode) {
if (not cached_mode) {
{num_recv_tokens, internode::get_source_meta_bytes()},
torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
recv_rdma_channel_prefix_matrix =
torch::empty({num_rdma_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_channel_prefix_matrix =
torch::empty({num_ranks, num_channels},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
send_rdma_head =
torch::empty({num_tokens, num_rdma_ranks},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
send_nvl_head =
torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS},
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
}
}
// Assign pointers
float *recv_topk_weights_ptr = nullptr;
float *recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
if (topk_idx.has_value()) {
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
? torch::empty({num_recv_tokens}, x_scales->options())
: torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = static_cast<float *>(recv_x_scales->data_ptr());
}
}
// Launch data dispatch
// NOTES: the buffer size checks are moved into the `.cu` file
recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr,
topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr<int>(),
cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
is_token_in_rank.data_ptr<bool>(), num_tokens, hidden_int4, num_scales, num_topk,
num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,
config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens,
config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, comm_stream,
num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
{x, is_token_in_rank, recv_x, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) {
t.record_stream(comm_stream);
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
{x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank,
num_tokens_per_expert, cached_rdma_channel_prefix_matrix,
cached_recv_rdma_rank_prefix_sum, cached_gbl_channel_prefix_matrix,
cached_recv_gbl_rank_prefix_sum, recv_topk_idx, recv_topk_weights, recv_x_scales,
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head,
send_nvl_head, recv_src_meta}) {
to.has_value() ? to->record_stream(comm_stream) : void();
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
recv_x_scales,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix,
gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix,
recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix,
recv_gbl_rank_prefix_sum,
recv_src_meta,
send_rdma_head,
send_nvl_head,
event};
#else
#else
"following docs/install_dependencies.md");
return {};
return {};
#endif
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM
const int num_channels = config.num_sms / 2;
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
src_meta.scalar_type() == torch::kByte);
EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and
is_combined_token_in_rank.is_contiguous() and
is_combined_token_in_rank.scalar_type() == torch::kBool);
EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and
rdma_channel_prefix_matrix.is_contiguous() and
rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and
rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and
gbl_channel_prefix_matrix.is_contiguous() and
gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and
combined_rdma_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and
combined_nvl_head.scalar_type() == torch::kInt32);
hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
internode::get_source_meta_bytes());
EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
rdma_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
gbl_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and
combined_rdma_head.size(0) == num_combined_tokens and
combined_rdma_head.size(1) == num_rdma_ranks);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and
combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Top-k checks
auto combined_topk_weights = std::optional<torch::Tensor>();
float *topk_weights_ptr = nullptr;
float *combined_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_weights_ptr = topk_weights->data_ptr<float>();
topk_weights_ptr = topk_weights->data_ptr<float>();
torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
}
// Extra check for avoid-dead-lock design
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
// Launch barrier and reset queue head and tail
hidden_int4, 0, 0, num_topk, num_ranks, num_channels, num_combined_tokens,
combined_rdma_head.data_ptr<int>(), rdma_channel_prefix_matrix.data_ptr<int>(),
rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(), rdma_buffer_ptr,
config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes,
false, low_latency_mode);
// Assign bias pointers
void *bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++i)
if (bias_opts[i].has_value()) {
EP_HOST_ASSERT(false and "bias is not supported in internode combine");
auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(),
combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr<bool>(), x.data_ptr(),
topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], combined_rdma_head.data_ptr<int>(),
combined_nvl_head.data_ptr<int>(), src_meta.data_ptr(),
rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), num_tokens, num_combined_tokens, hidden,
num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens,
config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank,
num_ranks, comm_stream, num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
rdma_rank_prefix_sum, gbl_channel_prefix_matrix, combined_x,
combined_rdma_head, combined_nvl_head}) {
t.record_stream(comm_stream);
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
to.has_value() ? to->record_stream(comm_stream) : void();
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {combined_x, combined_topk_weights, event};
#else
"following docs/install_dependencies.md");
return {};
return {};
#endif
}
int num_experts) {
EP_HOST_ASSERT(false and "not support low latency");
}
}
std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(false and "not support low latency");
return {};
return {};
}
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out) {
EP_HOST_ASSERT(false and "not support low latency");
return {};
return {};
}
int hidden, int num_experts) const {
EP_HOST_ASSERT(false and "not support low latency");
return {};
return {};
} // namespace primus_turbo::pytorch::deep_ep
// #include <ATen/dtk_macros.h>
#include <ATen/hip/HIPContext.h> #include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h> #include <ATen/hip/HIPDataType.h>
#include <chrono> #include <chrono>
...@@ -5,8 +6,8 @@ ...@@ -5,8 +6,8 @@
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <torch/python.h> #include <torch/python.h>
#include "./kernels/api.cuh" #include "kernels/api.cuh"
#include "./kernels/configs.cuh" #include "kernels/configs.cuh"
#include "deep_ep.hpp" #include "deep_ep.hpp"
namespace deep_ep { namespace deep_ep {
...@@ -40,8 +41,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ ...@@ -40,8 +41,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
// Get ranks // Get ranks
CUDA_CHECK(hipGetDevice(&device_id)); CUDA_CHECK(hipGetDevice(&device_id));
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_rdma_ranks = ::max(1, num_ranks / NUM_MAX_NVL_PEERS),
num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); num_nvl_ranks = ::min(num_ranks, NUM_MAX_NVL_PEERS);
#ifdef DISABLE_ROCSHMEM #ifdef DISABLE_ROCSHMEM
EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and
...@@ -803,8 +804,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -803,8 +804,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here. // here.
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
const int num_channels = config.num_sms / 2; const int num_channels = config.num_sms / 3;
EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(config.num_sms % 3 == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
...@@ -901,10 +902,10 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -901,10 +902,10 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// Allocate all tensors on comm stream if set // Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront! // NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream(); auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (allocate_on_comm_stream) { if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async); EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
} }
// Wait previous tasks to be finished // Wait previous tasks to be finished
...@@ -1088,7 +1089,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -1088,7 +1089,7 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// Switch back compute stream // Switch back compute stream
if (allocate_on_comm_stream) if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);
// Return values // Return values
return {recv_x, return {recv_x,
...@@ -1124,8 +1125,8 @@ Buffer::internode_combine( ...@@ -1124,8 +1125,8 @@ Buffer::internode_combine(
const torch::Tensor &combined_nvl_head, const Config &config, const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) { std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM #ifndef DISABLE_ROCSHMEM
const int num_channels = config.num_sms / 2; const int num_channels = config.num_sms / 3;
EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(config.num_sms % 3 == 0);
// Shape and contiguous checks // Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
...@@ -1167,10 +1168,10 @@ Buffer::internode_combine( ...@@ -1167,10 +1168,10 @@ Buffer::internode_combine(
// Allocate all tensors on comm stream if set // Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront! // NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream(); auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (allocate_on_comm_stream) { if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async); EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
} }
// Wait previous tasks to be finished // Wait previous tasks to be finished
...@@ -1216,7 +1217,7 @@ Buffer::internode_combine( ...@@ -1216,7 +1217,7 @@ Buffer::internode_combine(
void *bias_ptrs[2] = {nullptr, nullptr}; void *bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++i) for (int i = 0; i < 2; ++i)
if (bias_opts[i].has_value()) { if (bias_opts[i].has_value()) {
// EP_HOST_ASSERT(false and "bias is not supported in internode combine"); EP_HOST_ASSERT(false and "bias is not supported in internode combine");
auto bias = bias_opts[i].value(); auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
...@@ -1260,7 +1261,7 @@ Buffer::internode_combine( ...@@ -1260,7 +1261,7 @@ Buffer::internode_combine(
// Switch back compute stream // Switch back compute stream
if (allocate_on_comm_stream) if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream); at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);
// Return values // Return values
return {combined_x, combined_topk_weights, event}; return {combined_x, combined_topk_weights, event};
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h>
#include <chrono>
#include <hip/hip_runtime.h>
#include <pybind11/functional.h>
#include <torch/python.h>
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#include "deep_ep_hip.hpp"
namespace deep_ep {
// Constructs a communication buffer for one rank.
//
// Validates the requested sizes and rank layout, derives the RDMA/NVL rank
// coordinates, allocates the local NVL buffer (payload followed by barrier
// signals, the peer buffer pointer table and the peer barrier-signal pointer
// table), a device workspace, and the host-mapped MoE receive counters.
//
// Parameters:
//   rank / num_ranks  - this process' global rank and the world size.
//   num_nvl_bytes     - payload size of the local NVL buffer (0 skips the allocation).
//   num_rdma_bytes    - size of the RDMA buffer; it is only recorded here, the buffer
//                       itself is not allocated in the constructor.
//   low_latency_mode  - relaxes several of the size/rank assertions below.
//   explicitly_destroy - stored for the destructor, which then does not call destroy().
//   use_default_stream_as_comm_stream - when true the current stream becomes the
//                       communication stream, otherwise one is taken from the pool
//                       (the `true` argument presumably requests a high-priority
//                       stream - TODO confirm).
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
               bool low_latency_mode, bool explicitly_destroy,
               bool use_default_stream_as_comm_stream)
    : rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
      num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
      explicitly_destroy(explicitly_destroy),
      use_default_stream_as_comm_stream(use_default_stream_as_comm_stream),
      comm_stream(use_default_stream_as_comm_stream
                      ? at::hip::getCurrentHIPStreamMasqueradingAsCUDA()
                      : at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
    // Sizes of the three metadata regions appended after the NVL payload
    const int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
    const int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void *);
    const int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);

    // Argument validation
    EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
                   (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
    EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
                   (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
    EP_HOST_ASSERT(0 <= rank and rank < num_ranks and
                   (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
    EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
    if (num_rdma_bytes > 0) {
        EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode);
    }

    // Rank coordinates: `rank` is split into an RDMA index and an NVL index
    CUDA_CHECK(hipGetDevice(&device_id));
    rdma_rank = rank / NUM_MAX_NVL_PEERS;
    nvl_rank = rank % NUM_MAX_NVL_PEERS;
    num_rdma_ranks = ::max(1, num_ranks / NUM_MAX_NVL_PEERS);
    num_nvl_ranks = ::min(num_ranks, NUM_MAX_NVL_PEERS);
#ifdef DISABLE_ROCSHMEM
    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and
                   "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                   "following docs/install_dependencies.md");
#endif

    // Query the number of multiprocessors of the current device
    hipDeviceProp_t device_prop = {};
    CUDA_CHECK(hipGetDeviceProperties(&device_prop, device_id));
    num_device_sms = device_prop.multiProcessorCount;

    if (num_nvl_bytes > 0) {
        // One allocation holds the payload plus all metadata regions; its IPC handle
        // is exported so that peers can open it later
        CUDA_CHECK(hipExtMallocWithFlags(
            &buffer_ptrs[nvl_rank],
            num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes,
            hipDeviceMallocUncached));
        CUDA_CHECK(hipIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));

        // Carve the metadata regions out of the allocation:
        //   [payload][barrier signals][buffer pointers][barrier signal pointers]
        auto *local_base = static_cast<uint8_t *>(buffer_ptrs[nvl_rank]);
        auto *signal_region = local_base + num_nvl_bytes;
        barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int *>(signal_region);
        buffer_ptrs_gpu = reinterpret_cast<void **>(signal_region + barrier_signal_bytes);
        barrier_signal_ptrs_gpu =
            reinterpret_cast<int **>(signal_region + barrier_signal_bytes + buffer_ptr_bytes);

        // Zero the barrier signals asynchronously; `sync` performs a full device
        // synchronization later, so none is issued here
        CUDA_CHECK(
            hipMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
    }

    // Device workspace of NUM_WORKSPACE_BYTES bytes, zeroed on the communication stream
    CUDA_CHECK(hipMalloc(&workspace, NUM_WORKSPACE_BYTES));
    CUDA_CHECK(hipMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));

    // Host-mapped MoE receive counter, initialised to -1
    CUDA_CHECK(hipHostMalloc(&moe_recv_counter, sizeof(int64_t), hipHostMallocMapped));
    CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_counter_mapped),
                                       const_cast<int *>(moe_recv_counter), 0));
    *moe_recv_counter = -1;

    // Host-mapped per-expert receive counters, every slot initialised to -1
    CUDA_CHECK(hipHostMalloc(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS,
                             hipHostMallocMapped));
    CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_expert_counter_mapped),
                                       const_cast<int *>(moe_recv_expert_counter), 0));
    for (int expert = 0; expert < NUM_MAX_LOCAL_EXPERTS; ++expert) {
        moe_recv_expert_counter[expert] = -1;
    }

    // Host-mapped RDMA-level receive counter, initialised to -1
    if (num_rdma_ranks > 0) {
        CUDA_CHECK(hipHostMalloc(&moe_recv_rdma_counter, sizeof(int), hipHostMallocMapped));
        CUDA_CHECK(hipHostGetDevicePointer(reinterpret_cast<void **>(&moe_recv_rdma_counter_mapped),
                                           const_cast<int *>(moe_recv_rdma_counter), 0));
        *moe_recv_rdma_counter = -1;
    }
}
// Destructor. When explicit destruction was requested the owner must have called
// destroy() already; if not, only a warning is printed and nothing is released.
// Otherwise destroy() is invoked here. Declared noexcept(false), so a failure
// inside destroy() can propagate out of the destructor.
Buffer::~Buffer() noexcept(false) {
    if (explicitly_destroy) {
        if (not destroyed) {
            printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak "
                   "resources.\n");
            fflush(stdout);
        }
        return;
    }
    destroy();
}
// Returns the `available` flag, which sync() sets to true once it completes and
// destroy() resets to false.
bool Buffer::is_available() const {
return available;
}
// Returns true when the buffer is available and the world size exceeds
// NUM_MAX_NVL_PEERS.
bool Buffer::is_internode_available() const {
    if (not is_available())
        return false;
    return num_ranks > NUM_MAX_NVL_PEERS;
}
// Returns the number of RDMA ranks, computed in the constructor as
// max(1, num_ranks / NUM_MAX_NVL_PEERS).
int Buffer::get_num_rdma_ranks() const {
return num_rdma_ranks;
}
// Returns this rank's RDMA index, computed in the constructor as
// rank / NUM_MAX_NVL_PEERS.
int Buffer::get_rdma_rank() const {
return rdma_rank;
}
// Returns the root RDMA rank: the local NVL rank when `global` is true,
// otherwise 0.
int Buffer::get_root_rdma_rank(bool global) const {
    if (global)
        return nvl_rank;
    return 0;
}
// Returns the device id that the constructor obtained from hipGetDevice.
int Buffer::get_local_device_id() const {
return device_id;
}
// Returns the raw bytes (HIP_IPC_HANDLE_SIZE of them) of the IPC handle that
// belongs to the local NVL buffer, packed into a Python bytearray.
pybind11::bytearray Buffer::get_local_ipc_handle() const {
    const auto &local_handle = ipc_handles[nvl_rank];
    return pybind11::bytearray(local_handle.reserved, HIP_IPC_HANDLE_SIZE);
}
// Returns the rocSHMEM unique id obtained from internode::get_unique_id() as a
// Python bytearray. Only RDMA rank 0 may call this.
//
// When rocSHMEM support is compiled out (DISABLE_ROCSHMEM) the call always fails
// the host assertion; the trailing return only exists so that this non-void
// function has a return statement on every path.
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
#ifndef DISABLE_ROCSHMEM
    EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get ROCSHMEM unique ID");
    auto unique_id = internode::get_unique_id();
    return {reinterpret_cast<const char *>(unique_id.data()), unique_id.size()};
#else
    EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                             "following docs/install_dependencies.md");
    return {};
#endif
}
// Wraps the local NVL buffer (or the RDMA buffer when `use_rdma_buffer` is set)
// in a 1-D tensor of the requested dtype without copying.
//
// Parameters:
//   dtype           - Python dtype object, converted to a torch::ScalarType.
//   offset          - byte offset added to the start of the selected buffer.
//   use_rdma_buffer - selects rdma_buffer_ptr / num_rdma_bytes instead of the
//                     local NVL buffer / num_nvl_bytes.
//
// NOTE(review): the element count is derived from the full buffer size and does
// not subtract `offset`.
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
                                              bool use_rdma_buffer) const {
    torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
    auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));

    // Pick the backing region and its size
    void *region = nullptr;
    int64_t num_bytes = 0;
    if (use_rdma_buffer) {
        region = rdma_buffer_ptr;
        num_bytes = num_rdma_bytes;
    } else {
        region = buffer_ptrs[nvl_rank];
        num_bytes = num_nvl_bytes;
    }

    auto *base_ptr = static_cast<uint8_t *>(region) + offset;
    auto options = torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA);
    return torch::from_blob(base_ptr, num_bytes / element_bytes, options);
}
// Returns the communication stream selected in the constructor, converted to a
// torch::Stream.
torch::Stream Buffer::get_comm_stream() const {
return comm_stream;
}
void Buffer::destroy() {
EP_HOST_ASSERT(not destroyed);
// Synchronize
CUDA_CHECK(hipDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks,
comm_stream);
CUDA_CHECK(hipDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++i)
if (i != nvl_rank)
CUDA_CHECK(hipIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free local buffer and error flag
CUDA_CHECK(hipFree(buffer_ptrs[nvl_rank]));
}
// Free ROCSHMEM
#ifndef DISABLE_ROCSHMEM
if (is_available() and num_rdma_bytes > 0) {
CUDA_CHECK(hipDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
#endif
// Free workspace and MoE counter
CUDA_CHECK(hipFree(workspace));
CUDA_CHECK(hipFreeHost(const_cast<int *>(moe_recv_counter)));
// Free chunked mode staffs
CUDA_CHECK(hipFreeHost(const_cast<int *>(moe_recv_expert_counter)));
destroyed = true;
available = false;
}
// Finishes initialisation after the ranks have exchanged their handles. Must be
// called while the buffer is not yet available; marks it available on success.
//
// Parameters:
//   device_ids           - one entry per rank; only its size is checked here.
//   all_gathered_handles - IPC handle bytes gathered from every rank. The entries of
//                          the ranks that share this rank's RDMA index are opened.
//   root_unique_id_opt   - rocSHMEM unique id; required when num_rdma_bytes > 0.
void Buffer::sync(const std::vector<int> &device_ids,
                  const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
                  const std::optional<pybind11::bytearray> &root_unique_id_opt) {
    EP_HOST_ASSERT(not is_available());

    // Open the IPC handles of the NVL peers
    if (num_nvl_bytes > 0) {
        EP_HOST_ASSERT(num_ranks == device_ids.size());
        EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());

        // Global rank of the first NVL peer in this rank's group
        const int offset = rdma_rank * num_nvl_ranks;
        for (int i = 0; i < num_nvl_ranks; ++i) {
            EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
            auto handle_str = std::string(all_gathered_handles[offset + i].value());
            EP_HOST_ASSERT(handle_str.size() == HIP_IPC_HANDLE_SIZE);

            if (offset + i == rank) {
                // Own entry: it must match the handle exported by the constructor
                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(),
                                           HIP_IPC_HANDLE_SIZE) == 0);
                continue;
            }

            // Peer entry: map its buffer and locate its barrier signals, which sit
            // right after the payload
            std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), HIP_IPC_HANDLE_SIZE);
            CUDA_CHECK(hipIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i],
                                           hipIpcMemLazyEnablePeerAccess));
            auto *peer_base = static_cast<uint8_t *>(buffer_ptrs[i]);
            barrier_signal_ptrs[i] = reinterpret_cast<int *>(peer_base + num_nvl_bytes);
        }

        // Publish both pointer tables to the device
        CUDA_CHECK(hipMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void *) * NUM_MAX_NVL_PEERS,
                             hipMemcpyHostToDevice));
        CUDA_CHECK(hipMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs,
                             sizeof(int *) * NUM_MAX_NVL_PEERS, hipMemcpyHostToDevice));
        CUDA_CHECK(hipDeviceSynchronize());
    }

#ifndef DISABLE_ROCSHMEM
    // Bring up rocSHMEM and allocate the RDMA buffer
    if (num_rdma_bytes > 0) {
        EP_HOST_ASSERT(root_unique_id_opt.has_value());

        // Copy the unique id bytes out of the Python object
        std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
        auto unique_id_bytes = root_unique_id_opt->cast<std::string>();
        std::memcpy(root_unique_id.data(), unique_id_bytes.c_str(), root_unique_id_opt->size());

        // Low-latency mode spans all ranks; otherwise only the RDMA ranks take part
        auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
        auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
        EP_HOST_ASSERT(nvshmem_rank ==
                       internode::init(
                           root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
        internode::barrier();

        // Allocate and zero the RDMA buffer (zeroing is mainly for low-latency mode)
        rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
        CUDA_CHECK(hipMemset(rdma_buffer_ptr, 0, num_rdma_bytes));

        internode::barrier();
        CUDA_CHECK(hipDeviceSynchronize());
    }
#endif

    // Ready to use
    available = true;
}
// Computes the dispatch layout for a batch of top-k expert indices by launching
// layout::get_dispatch_layout on the communication stream.
//
// Parameters:
//   topk_idx       - contiguous 2-D tensor; its two sizes are taken as
//                    (num_tokens, num_topk) and its data is read as int64.
//   num_experts    - total expert count, must be positive.
//   previous_event - if set, the communication stream waits on it instead of on the
//                    current compute stream.
//   async          - when true an event recorded on the communication stream is
//                    returned and tensors are registered with the involved streams;
//                    otherwise the compute stream waits for the communication stream.
//   allocate_on_comm_stream - makes the communication stream current while the
//                    outputs are allocated; requires `previous_event` and `async`.
//
// Returns (num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
// is_token_in_rank, event). The RDMA tensor is only present when internode
// communication is available; the event is only present in async mode.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
           std::optional<EventHandle>>
Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
                            std::optional<EventHandle> &previous_event, bool async,
                            bool allocate_on_comm_stream) {
    EP_HOST_ASSERT(topk_idx.dim() == 2);
    EP_HOST_ASSERT(topk_idx.is_contiguous());
    EP_HOST_ASSERT(num_experts > 0);

    // Optionally make the communication stream current so that the tensors below
    // are allocated on it. NOTES: do not allocate tensors upfront!
    auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
    }

    // Order the communication stream after the preceding work; skipped when the
    // default stream doubles as the communication stream
    if (not use_default_stream_as_comm_stream) {
        if (previous_event.has_value())
            stream_wait(comm_stream, previous_event.value());
        else
            stream_wait(comm_stream, compute_stream);
    }

    const int num_tokens = static_cast<int>(topk_idx.size(0));
    const int num_topk = static_cast<int>(topk_idx.size(1));

    // Output tensors
    const auto int32_opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    const auto bool_opts = torch::TensorOptions().dtype(torch::kBool).device(torch::kCUDA);
    auto num_tokens_per_rank = torch::empty({num_ranks}, int32_opts);
    auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
    auto num_tokens_per_expert = torch::empty({num_experts}, int32_opts);
    auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, bool_opts);
    if (is_internode_available())
        num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, int32_opts);

    // The RDMA counts pointer stays null when that output is absent
    int *rdma_rank_counts_ptr = nullptr;
    if (num_tokens_per_rdma_rank.has_value())
        rdma_rank_counts_ptr = num_tokens_per_rdma_rank.value().data_ptr<int>();

    layout::get_dispatch_layout(topk_idx.data_ptr<int64_t>(), num_tokens_per_rank.data_ptr<int>(),
                                rdma_rank_counts_ptr, num_tokens_per_expert.data_ptr<int>(),
                                is_token_in_rank.data_ptr<bool>(), num_tokens, num_topk, num_ranks,
                                num_experts, comm_stream);

    // Hand the results back to the compute stream
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        // Register every tensor with the streams that touch it
        for (const auto &tensor :
             {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
            tensor.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                tensor.record_stream(compute_stream);
        }
        if (num_tokens_per_rdma_rank.has_value()) {
            num_tokens_per_rdma_rank->record_stream(comm_stream);
            if (allocate_on_comm_stream)
                num_tokens_per_rdma_rank->record_stream(compute_stream);
        }
    } else if (not use_default_stream_as_comm_stream) {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);

    return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
            event};
}
// Intranode dispatch: validates the inputs, exchanges (or reuses cached) layout metadata,
// allocates the receive tensors and launches `intranode::dispatch` on `comm_stream`.
//
// The mode is selected by `cached_rank_prefix_matrix`:
//  - cached mode (value present): `cached_rank_prefix_matrix`, `cached_channel_prefix_matrix` and
//    `cached_num_recv_tokens` are reused, only `intranode::cached_notify_dispatch` is launched;
//  - non-cached mode: `num_tokens_per_rank` and `num_tokens_per_expert` are required,
//    `intranode::notify_dispatch` is launched and, unless `num_worst_tokens > 0`, the host
//    busy-waits on `moe_recv_counter` / `moe_recv_expert_counter` until all are non-negative.
//
// Input requirements (all enforced with `EP_HOST_ASSERT`):
//  - `x`: 2D, contiguous, row size in bytes is a multiple of `sizeof(int4)`;
//  - `is_token_in_rank`: bool, 2D, contiguous, shaped `[x.size(0), num_ranks]`;
//  - `topk_idx` / `topk_weights`: both present or both absent, 2D, contiguous, one row per token
//    (`topk_weights` must be float32, `topk_idx` is read as int64);
//  - `x_scales`: 2D, float32 or int32, one row per token, only valid with 1-byte `x` elements;
//  - `config.num_sms` must be even (two blocks per channel);
//  - `allocate_on_comm_stream` requires `previous_event` to be set and `async` to be true.
//
// Returns `(recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
// num_recv_tokens_per_expert_list, num_recv_tokens_per_expert, rank_prefix_matrix,
// channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event)`;
// `event` is only set when `async` is true, `num_recv_tokens_per_expert_list` is only filled in
// non-cached mode without `num_worst_tokens`.
//
// Throws `std::runtime_error` when the CPU wait exceeds `NUM_CPU_TIMEOUT_SECS` seconds.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
           std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
           torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::intranode_dispatch(
    const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
    const std::optional<torch::Tensor> &topk_idx, const std::optional<torch::Tensor> &topk_weights,
    const std::optional<torch::Tensor> &num_tokens_per_rank, const torch::Tensor &is_token_in_rank,
    const std::optional<torch::Tensor> &num_tokens_per_expert, int cached_num_recv_tokens,
    const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
    const std::optional<torch::Tensor> &cached_channel_prefix_matrix, int expert_alignment,
    int num_worst_tokens, const Config &config, std::optional<EventHandle> &previous_event,
    bool async, bool allocate_on_comm_stream) {
    bool cached_mode = cached_rank_prefix_matrix.has_value();

    // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for
    // receiving.
    EP_HOST_ASSERT(config.num_sms % 2 == 0);
    int num_channels = config.num_sms / 2;
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());
        EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank.has_value());
        EP_HOST_ASSERT(num_tokens_per_expert.has_value());
    }

    // Type checks
    EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);
    } else {
        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
    }

    // Shape and contiguous checks
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
    EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
    EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and
                   is_token_in_rank.size(1) == num_ranks);
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and
                       cached_rank_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and
                       cached_rank_prefix_matrix->size(1) == num_ranks);
        EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and
                       cached_channel_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and
                       cached_channel_prefix_matrix->size(1) == num_channels);
    } else {
        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and
                       num_tokens_per_expert->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and
                       num_tokens_per_rank->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
    }

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
    // NOTES: cached mode carries no expert information, so both counts are zero there
    auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)),
         num_local_experts = num_experts / num_ranks;

    // Top-k checks
    int num_topk = 0;
    int64_t *topk_idx_ptr = nullptr;
    float *topk_weights_ptr = nullptr;
    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
    if (topk_idx.has_value()) {
        // NOTES: the rank is validated before `size(1)` is read
        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        num_topk = static_cast<int>(topk_idx->size(1));
        EP_HOST_ASSERT(num_experts > 0);
        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
        EP_HOST_ASSERT(num_topk == topk_weights->size(1));
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        topk_idx_ptr = topk_idx->data_ptr<int64_t>();
        topk_weights_ptr = topk_weights->data_ptr<float>();
    }

    // FP8 scales checks
    float *x_scales_ptr = nullptr;
    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
    if (x_scales.has_value()) {
        EP_HOST_ASSERT(x.element_size() == 1);
        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or
                       x_scales->scalar_type() == torch::kInt);
        EP_HOST_ASSERT(x_scales->dim() == 2);
        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
        // NOTES: scales are asserted to be 2D above, so the second dimension always exists
        num_scales = static_cast<int>(x_scales->size(1));
        // NOTES: int32 scales are passed through the same `float *` as raw 4-byte values
        x_scales_ptr = static_cast<float *>(x_scales->data_ptr());
        scale_token_stride = static_cast<int>(x_scales->stride(0));
        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
    }

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
    }

    // Wait previous tasks to be finished
    if (not use_default_stream_as_comm_stream) {
        if (previous_event.has_value()) {
            stream_wait(comm_stream, previous_event.value());
        } else {
            stream_wait(comm_stream, compute_stream);
        }
    }

    // Create handles (only return for non-cached mode)
    int num_recv_tokens = -1;
    auto rank_prefix_matrix = torch::Tensor();
    auto channel_prefix_matrix = torch::Tensor();
    std::vector<int> num_recv_tokens_per_expert_list;
    torch::Tensor num_recv_tokens_per_expert = torch::empty(
        {num_local_experts}, torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA));

    // Barrier or send sizes
    // To clean: channel start/end offset, head and tail
    int num_memset_int = num_channels * num_ranks * 4;
    if (cached_mode) {
        num_recv_tokens = cached_num_recv_tokens;
        rank_prefix_matrix = cached_rank_prefix_matrix.value();
        channel_prefix_matrix = cached_channel_prefix_matrix.value();

        // Copy rank prefix matrix and clean flags
        intranode::cached_notify_dispatch(
            rank_prefix_matrix.data_ptr<int>(), num_memset_int, buffer_ptrs_gpu,
            barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream);
    } else {
        rank_prefix_matrix =
            torch::empty({num_ranks, num_ranks},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        channel_prefix_matrix =
            torch::empty({num_ranks, num_channels},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));

        // Send sizes
        // Meta information:
        //  - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`
        //  - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`
        // NOTES: no more token dropping in this version
        // NOTES: the host counters are reset to -1 so the wait loop below can tell when they
        // have been filled in
        *moe_recv_counter = -1;
        for (int i = 0; i < num_local_experts; ++i)
            moe_recv_expert_counter[i] = -1;
        EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <=
                       num_nvl_bytes);
        intranode::notify_dispatch(
            num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
            num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped,
            num_recv_tokens_per_expert.data_ptr<int64_t>(), num_experts, num_tokens,
            is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
            rank_prefix_matrix.data_ptr<int>(), num_memset_int, expert_alignment, buffer_ptrs_gpu,
            barrier_signal_ptrs_gpu, rank, comm_stream, num_channels);

        if (num_worst_tokens > 0) {
            // No CPU sync, just allocate the worst case
            num_recv_tokens = num_worst_tokens;

            // Must be forward with top-k stuffs
            EP_HOST_ASSERT(topk_idx.has_value());
            EP_HOST_ASSERT(topk_weights.has_value());
        } else {
            // Synchronize total received tokens and tokens per expert
            auto start_time = std::chrono::high_resolution_clock::now();
            while (true) {
                // Read total count
                num_recv_tokens = static_cast<int>(*moe_recv_counter);

                // Read per-expert count
                bool ready = (num_recv_tokens >= 0);
                for (int i = 0; i < num_local_experts and ready; ++i)
                    ready &= moe_recv_expert_counter[i] >= 0;
                if (ready)
                    break;

                // Timeout check
                if (std::chrono::duration_cast<std::chrono::seconds>(
                        std::chrono::high_resolution_clock::now() - start_time)
                        .count() > NUM_CPU_TIMEOUT_SECS)
                    throw std::runtime_error("DeepEP error: CPU recv timeout");
            }
            num_recv_tokens_per_expert_list = std::vector<int>(
                moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
        }
    }

    // Allocate new tensors
    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
    auto recv_src_idx = torch::empty(
        {num_recv_tokens}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
    auto recv_topk_idx = std::optional<torch::Tensor>(),
         recv_topk_weights = std::optional<torch::Tensor>(),
         recv_x_scales = std::optional<torch::Tensor>();
    auto recv_channel_prefix_matrix =
        torch::empty({num_ranks, num_channels},
                     torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
    auto send_head = torch::empty({num_tokens, num_ranks},
                                  torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));

    // Assign pointers
    int64_t *recv_topk_idx_ptr = nullptr;
    float *recv_topk_weights_ptr = nullptr;
    float *recv_x_scales_ptr = nullptr;
    if (topk_idx.has_value()) {
        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
        recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
    }
    if (x_scales.has_value()) {
        recv_x_scales = torch::empty({num_recv_tokens, num_scales}, x_scales->options());
        recv_x_scales_ptr = static_cast<float *>(recv_x_scales->data_ptr());
    }

    // Dispatch
    // NOTES: the required size is accumulated in `size_t`, multiplying several `int`s first
    // could overflow for large channel/chunk/hidden configurations
    const auto num_channel_ranks =
        static_cast<size_t>(num_channels) * static_cast<size_t>(num_ranks);
    const auto num_recv_slots =
        num_channel_ranks * static_cast<size_t>(config.num_max_nvl_chunked_recv_tokens);
    const auto num_bytes_per_token =
        static_cast<size_t>(hidden) * static_cast<size_t>(recv_x.element_size());
    EP_HOST_ASSERT(static_cast<size_t>(num_ranks) * num_ranks * sizeof(int) + // Size prefix matrix
                       num_channel_ranks * sizeof(int) +     // Channel start offset
                       num_channel_ranks * sizeof(int) +     // Channel end offset
                       num_channel_ranks * sizeof(int) * 2 + // Queue head and tail
                       num_recv_slots * num_bytes_per_token + // Data buffer
                       num_recv_slots * sizeof(int) +         // Source index buffer
                       num_recv_slots * num_topk * sizeof(int64_t) + // Top-k index buffer
                       num_recv_slots * num_topk * sizeof(float) +   // Top-k weight buffer
                       num_recv_slots * sizeof(float) * num_scales   // FP8 scale buffer
                   <= num_nvl_bytes);
    intranode::dispatch(
        recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr,
        recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
        send_head.data_ptr<int>(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
        is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(), num_tokens,
        num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk,
        num_experts, num_scales, scale_token_stride, scale_hidden_stride, buffer_ptrs_gpu, rank,
        num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens,
        config.num_max_nvl_chunked_recv_tokens);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        // NOTES: every tensor touched on `comm_stream` is recorded on it; with
        // `allocate_on_comm_stream` the tensors are additionally recorded on `compute_stream`.
        // `num_recv_tokens_per_expert` is written by `notify_dispatch` on `comm_stream` and
        // returned to the caller, so it is recorded like the other outputs.
        for (auto &t : {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x,
                        recv_src_idx, recv_channel_prefix_matrix, send_head,
                        num_recv_tokens_per_expert}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto &to :
             {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
              cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx,
              recv_topk_weights, recv_x_scales}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        // Synchronous mode: make the compute stream wait for the communication stream
        if (not use_default_stream_as_comm_stream) {
            stream_wait(compute_stream, comm_stream);
        }
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);

    // Return values
    return {recv_x,
            recv_x_scales,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            num_recv_tokens_per_expert,
            rank_prefix_matrix,
            channel_prefix_matrix,
            recv_channel_prefix_matrix,
            recv_src_idx,
            send_head,
            event};
}
// Intranode combine: validates the inputs, launches `intranode::cached_notify_combine` followed
// by `intranode::combine` on `comm_stream` and returns the combined tensors.
//
// Input requirements (all enforced with `EP_HOST_ASSERT`):
//  - `x`: 2D, contiguous, row size in bytes is a multiple of `sizeof(int4)`;
//  - `src_idx`: int32, 1D, contiguous, one entry per row of `x`;
//  - `send_head`: int32, 2D, contiguous, shaped `[num_recv_tokens, num_ranks]`; its first
//    dimension defines the number of rows of the outputs;
//  - `rank_prefix_matrix`: int32, contiguous, shaped `[num_ranks, num_ranks]`;
//  - `channel_prefix_matrix`: int32, contiguous, shaped `[num_ranks, num_channels]`;
//  - `topk_weights` (optional): float32, 2D, contiguous, one row per row of `x`;
//  - `bias_0` / `bias_1` (optional): 2D, contiguous, same dtype as `x`, shaped
//    `[num_recv_tokens, hidden]`;
//  - `config.num_sms` must be even (two blocks per channel);
//  - `allocate_on_comm_stream` requires `previous_event` to be set and `async` to be true.
//
// Returns `(recv_x, recv_topk_weights, event)`; `recv_topk_weights` is only set when
// `topk_weights` is given, `event` is only set when `async` is true.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
                          const std::optional<torch::Tensor> &bias_0,
                          const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
                          const torch::Tensor &rank_prefix_matrix,
                          const torch::Tensor &channel_prefix_matrix,
                          const torch::Tensor &send_head, const Config &config,
                          std::optional<EventHandle> &previous_event, bool async,
                          bool allocate_on_comm_stream) {
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and
                   src_idx.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and
                   send_head.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and
                   rank_prefix_matrix.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and
                   channel_prefix_matrix.is_contiguous() and
                   channel_prefix_matrix.scalar_type() == torch::kInt32);

    // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for
    // receiving.
    EP_HOST_ASSERT(config.num_sms % 2 == 0);
    int num_channels = config.num_sms / 2;

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
    auto num_recv_tokens = static_cast<int>(send_head.size(0));
    EP_HOST_ASSERT(src_idx.size(0) == num_tokens);
    EP_HOST_ASSERT(send_head.size(1) == num_ranks);
    EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and
                   rank_prefix_matrix.size(1) == num_ranks);
    EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and
                   channel_prefix_matrix.size(1) == num_channels);
    EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
    }

    // Wait previous tasks to be finished
    if (not use_default_stream_as_comm_stream) {
        if (previous_event.has_value()) {
            stream_wait(comm_stream, previous_event.value());
        } else {
            stream_wait(comm_stream, compute_stream);
        }
    }

    // Top-k weights (optional)
    int num_topk = 0;
    auto recv_topk_weights = std::optional<torch::Tensor>();
    float *topk_weights_ptr = nullptr;
    float *recv_topk_weights_ptr = nullptr;
    if (topk_weights.has_value()) {
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        num_topk = static_cast<int>(topk_weights->size(1));
        topk_weights_ptr = topk_weights->data_ptr<float>();
        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
    }

    // Launch barrier and reset queue head and tail
    EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
    intranode::cached_notify_combine(
        buffer_ptrs_gpu, send_head.data_ptr<int>(), num_channels, num_recv_tokens,
        num_channels * num_ranks * 2, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream);

    // Assign bias pointers
    // NOTES: the optionals are referenced in place, no tensor handle is copied
    const std::optional<torch::Tensor> *bias_opts[2] = {&bias_0, &bias_1};
    void *bias_ptrs[2] = {nullptr, nullptr};
    for (int i = 0; i < 2; ++i) {
        if (not bias_opts[i]->has_value())
            continue;
        const auto &bias = bias_opts[i]->value();
        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
        EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
        bias_ptrs[i] = bias.data_ptr();
    }

    // Combine data
    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
    // NOTES: the required size is accumulated in `size_t`, multiplying several `int`s first
    // could overflow for large channel/chunk/hidden configurations
    const auto num_channel_ranks =
        static_cast<size_t>(num_channels) * static_cast<size_t>(num_ranks);
    const auto num_recv_slots =
        num_channel_ranks * static_cast<size_t>(config.num_max_nvl_chunked_recv_tokens);
    const auto num_bytes_per_token =
        static_cast<size_t>(hidden) * static_cast<size_t>(x.element_size());
    EP_HOST_ASSERT(num_channel_ranks * sizeof(int) * 2 +      // Queue head and tail
                       num_recv_slots * num_bytes_per_token + // Data buffer
                       num_recv_slots * sizeof(int) +         // Source index buffer
                       num_recv_slots * num_topk * sizeof(float) // Top-k weight buffer
                   <= num_nvl_bytes);
    intranode::combine(
        at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(),
        recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1],
        src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(),
        channel_prefix_matrix.data_ptr<int>(), send_head.data_ptr<int>(), num_tokens,
        num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream,
        config.num_sms, config.num_max_nvl_chunked_send_tokens,
        config.num_max_nvl_chunked_recv_tokens);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        // NOTES: every tensor touched on `comm_stream` is recorded on it; with
        // `allocate_on_comm_stream` the tensors are additionally recorded on `compute_stream`
        for (auto &t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto &to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        // Synchronous mode: make the compute stream wait for the communication stream
        if (not use_default_stream_as_comm_stream) {
            stream_wait(compute_stream, comm_stream);
        }
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);
    return {recv_x, recv_topk_weights, event};
}
// Internode dispatch: validates the inputs, exchanges (or reuses cached) layout metadata,
// allocates the receive tensors and launches `internode::dispatch` on `comm_stream`.
// Only available when the file is compiled without `DISABLE_ROCSHMEM`; otherwise the call
// fails an `EP_HOST_ASSERT`.
//
// The mode is selected by `cached_rdma_channel_prefix_matrix`:
//  - cached mode (value present): the four `cached_*` tensors and the two `cached_num_*` counts
//    are reused and only `internode::cached_notify` is launched; `recv_src_meta`, the
//    `recv_*_channel_prefix_matrix` tensors and the `send_*_head` tensors are not produced;
//  - non-cached mode: `num_tokens_per_rank`, `num_tokens_per_rdma_rank` and
//    `num_tokens_per_expert` are required, `internode::notify_dispatch` is launched and the
//    host busy-waits on the receive counters until all of them are non-negative.
//
// Input requirements (all enforced with `EP_HOST_ASSERT`):
//  - `x`: 2D, contiguous, row size in bytes is a multiple of `sizeof(int4)`;
//  - `is_token_in_rank`: bool, 2D, contiguous, shaped `[x.size(0), num_ranks]`;
//  - `topk_idx` / `topk_weights`: both present or both absent, 2D, contiguous, one row per token
//    (`topk_weights` must be float32, `topk_idx` is read as int64);
//  - `x_scales`: 2D, float32 or int32, one row per token, only valid with 1-byte `x` elements;
//  - `config.num_sms` must be a multiple of 3;
//  - `allocate_on_comm_stream` requires `previous_event` to be set and `async` to be true.
//
// Returns `(recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
// num_recv_tokens_per_expert_list, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
// recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix,
// recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, send_nvl_head, event)`.
//
// Throws `std::runtime_error` when the CPU wait exceeds `NUM_CPU_TIMEOUT_SECS` seconds.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
           std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
           std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor,
           std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
           std::optional<EventHandle>>
Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
                           const std::optional<torch::Tensor> &topk_idx,
                           const std::optional<torch::Tensor> &topk_weights,
                           const std::optional<torch::Tensor> &num_tokens_per_rank,
                           const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
                           const torch::Tensor &is_token_in_rank,
                           const std::optional<torch::Tensor> &num_tokens_per_expert,
                           int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
                           const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
                           const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
                           const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
                           const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
                           int expert_alignment, const Config &config,
                           std::optional<EventHandle> &previous_event, bool async,
                           bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM
    // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks,
    // which can be quite long. If users of DeepEP need to execute other Python code on other
    // threads, such as KV transfer, their code will get stuck due to GIL unless we release GIL
    // here.
    pybind11::gil_scoped_release release;

    const int num_channels = config.num_sms / 3;
    EP_HOST_ASSERT(config.num_sms % 3 == 0);
    EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);

    bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank.has_value());
        EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
        EP_HOST_ASSERT(num_tokens_per_expert.has_value());
    }

    // Type checks
    // NOTES: `is_token_in_rank` is passed to the kernels as a raw `bool *`, so its type and
    // layout are validated here in the same way as in `intranode_dispatch`
    EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
        EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
    }

    // Shape and contiguous checks
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
    EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
    EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and
                   is_token_in_rank.size(1) == num_ranks);
    if (cached_mode) {
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and
                       cached_rdma_channel_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and
                       cached_rdma_channel_prefix_matrix->size(1) == num_channels);
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and
                       cached_recv_rdma_rank_prefix_sum->is_contiguous());
        EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and
                       cached_gbl_channel_prefix_matrix->is_contiguous());
        EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and
                       cached_gbl_channel_prefix_matrix->size(1) == num_channels);
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and
                       cached_recv_gbl_rank_prefix_sum->is_contiguous());
        EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
    } else {
        EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and
                       num_tokens_per_rank->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and
                       num_tokens_per_rdma_rank->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and
                       num_tokens_per_expert->is_contiguous());
        EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
        EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
        EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
    }

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)),
         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
    // NOTES: cached mode carries no expert information, so both counts are zero there
    auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)),
         num_local_experts = num_experts / num_ranks;

    // Top-k checks
    int num_topk = 0;
    int64_t *topk_idx_ptr = nullptr;
    float *topk_weights_ptr = nullptr;
    EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
    if (topk_idx.has_value()) {
        // NOTES: the rank is validated before `size(1)` is read
        EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        num_topk = static_cast<int>(topk_idx->size(1));
        EP_HOST_ASSERT(num_experts > 0);
        EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
        EP_HOST_ASSERT(num_topk == topk_weights->size(1));
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        topk_idx_ptr = topk_idx->data_ptr<int64_t>();
        topk_weights_ptr = topk_weights->data_ptr<float>();
    }

    // FP8 scales checks
    float *x_scales_ptr = nullptr;
    int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
    if (x_scales.has_value()) {
        EP_HOST_ASSERT(x.element_size() == 1);
        EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or
                       x_scales->scalar_type() == torch::kInt);
        EP_HOST_ASSERT(x_scales->dim() == 2);
        EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
        // NOTES: scales are asserted to be 2D above, so the second dimension always exists
        num_scales = static_cast<int>(x_scales->size(1));
        // NOTES: int32 scales are passed through the same `float *` as raw 4-byte values
        x_scales_ptr = static_cast<float *>(x_scales->data_ptr());
        scale_token_stride = static_cast<int>(x_scales->stride(0));
        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
    }

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
    }

    // Wait previous tasks to be finished
    if (previous_event.has_value()) {
        stream_wait(comm_stream, previous_event.value());
    } else {
        stream_wait(comm_stream, compute_stream);
    }

    // Create handles (only return for non-cached mode)
    int num_recv_tokens = -1, num_rdma_recv_tokens = -1;
    auto rdma_channel_prefix_matrix = torch::Tensor();
    auto recv_rdma_rank_prefix_sum = torch::Tensor();
    auto gbl_channel_prefix_matrix = torch::Tensor();
    auto recv_gbl_rank_prefix_sum = torch::Tensor();
    std::vector<int> num_recv_tokens_per_expert_list;

    // Barrier or send sizes
    if (cached_mode) {
        num_recv_tokens = cached_num_recv_tokens;
        num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
        rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
        recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();
        gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
        recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();

        // Just a barrier and clean flags
        internode::cached_notify(
            hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr,
            nullptr, nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
            buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank,
            comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
            num_nvl_bytes, true, low_latency_mode);
    } else {
        rdma_channel_prefix_matrix =
            torch::empty({num_rdma_ranks, num_channels},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        recv_rdma_rank_prefix_sum = torch::empty(
            {num_rdma_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        gbl_channel_prefix_matrix =
            torch::empty({num_ranks, num_channels},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        recv_gbl_rank_prefix_sum = torch::empty(
            {num_ranks}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));

        // Send sizes
        // NOTES: the host counters are reset to -1 so the wait loop below can tell when they
        // have been filled in
        *moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
        for (int i = 0; i < num_local_experts; ++i)
            moe_recv_expert_counter[i] = -1;
        internode::notify_dispatch(
            num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
            num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped,
            num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
            is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels, hidden_int4, num_scales,
            num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr<int>(),
            recv_rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
            recv_gbl_rank_prefix_sum.data_ptr<int>(), rdma_buffer_ptr,
            config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
            config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream,
            config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes,
            low_latency_mode);

        // Synchronize total received tokens and tokens per expert
        auto start_time = std::chrono::high_resolution_clock::now();
        while (true) {
            // Read total count
            num_recv_tokens = static_cast<int>(*moe_recv_counter);
            num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);

            // Read per-expert count
            bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
            for (int i = 0; i < num_local_experts and ready; ++i)
                ready &= moe_recv_expert_counter[i] >= 0;
            if (ready)
                break;

            // Timeout check
            if (std::chrono::duration_cast<std::chrono::seconds>(
                    std::chrono::high_resolution_clock::now() - start_time)
                    .count() > NUM_CPU_TIMEOUT_SECS) {
                // Dump the counters observed so far before failing
                printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank,
                       num_recv_tokens, num_rdma_recv_tokens);
                for (int i = 0; i < num_local_experts; ++i)
                    printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
                throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
            }
        }
        num_recv_tokens_per_expert_list =
            std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
    }

    // Allocate new tensors
    auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
    auto recv_topk_idx = std::optional<torch::Tensor>(),
         recv_topk_weights = std::optional<torch::Tensor>(),
         recv_x_scales = std::optional<torch::Tensor>();
    auto recv_src_meta = std::optional<torch::Tensor>();
    auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
    auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();
    auto send_rdma_head = std::optional<torch::Tensor>();
    auto send_nvl_head = std::optional<torch::Tensor>();
    if (not cached_mode) {
        recv_src_meta = torch::empty(
            {num_recv_tokens, internode::get_source_meta_bytes()},
            torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
        recv_rdma_channel_prefix_matrix =
            torch::empty({num_rdma_ranks, num_channels},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        recv_gbl_channel_prefix_matrix =
            torch::empty({num_ranks, num_channels},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        send_rdma_head =
            torch::empty({num_tokens, num_rdma_ranks},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
        send_nvl_head =
            torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS},
                         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
    }

    // Assign pointers
    int64_t *recv_topk_idx_ptr = nullptr;
    float *recv_topk_weights_ptr = nullptr;
    float *recv_x_scales_ptr = nullptr;
    if (topk_idx.has_value()) {
        recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
        recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
        recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
        recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
    }
    if (x_scales.has_value()) {
        recv_x_scales = torch::empty({num_recv_tokens, num_scales}, x_scales->options());
        recv_x_scales_ptr = static_cast<float *>(recv_x_scales->data_ptr());
    }

    // Launch data dispatch
    // NOTES: the buffer size checks are moved into the `.cu` file
    internode::dispatch(
        recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
        cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr,
        topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr<int>(),
        cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
        cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
        cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
        rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
        gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
        is_token_in_rank.data_ptr<bool>(), num_tokens, hidden_int4, num_scales, num_topk,
        num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,
        config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
        buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens,
        config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, comm_stream,
        num_channels, low_latency_mode);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        // NOTES: every tensor touched on `comm_stream` is recorded on it; with
        // `allocate_on_comm_stream` the tensors are additionally recorded on `compute_stream`
        for (auto &t :
             {x, is_token_in_rank, recv_x, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
              gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto &to :
             {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank,
              num_tokens_per_expert, cached_rdma_channel_prefix_matrix,
              cached_recv_rdma_rank_prefix_sum, cached_gbl_channel_prefix_matrix,
              cached_recv_gbl_rank_prefix_sum, recv_topk_idx, recv_topk_weights, recv_x_scales,
              recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head,
              send_nvl_head, recv_src_meta}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        // Synchronous mode: make the compute stream wait for the communication stream
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);

    // Return values
    return {recv_x,
            recv_x_scales,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            rdma_channel_prefix_matrix,
            gbl_channel_prefix_matrix,
            recv_rdma_channel_prefix_matrix,
            recv_rdma_rank_prefix_sum,
            recv_gbl_channel_prefix_matrix,
            recv_gbl_rank_prefix_sum,
            recv_src_meta,
            send_rdma_head,
            send_nvl_head,
            event};
#else
    EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                             "following docs/install_dependencies.md");
    return {};
#endif
}
/**
 * Host-side launcher for the internode combine step.
 *
 * Validates all inputs, then (on `comm_stream`) launches `internode::cached_notify`
 * followed by `internode::combine`, and returns the combined activations.
 *
 * All argument validation is performed BEFORE the current stream is redirected and
 * before any kernel is enqueued, so a rejected call leaves no side effects behind
 * (no half-issued barrier kernel, no current stream left pointing at `comm_stream`).
 *
 * @param x                          2-D contiguous tensor `[num_tokens, hidden]`;
 *                                   `hidden * element_size` must be a multiple of `sizeof(int4)`.
 * @param topk_weights               Optional 2-D contiguous float32 tensor `[num_tokens, num_topk]`.
 * @param bias_0, bias_1             Must both be empty: bias is not supported by this
 *                                   internode combine and is rejected up front.
 * @param src_meta                   2-D contiguous byte tensor; second dimension must equal
 *                                   `internode::get_source_meta_bytes()`.
 * @param is_combined_token_in_rank  2-D contiguous bool tensor `[num_combined_tokens, num_ranks]`.
 * @param rdma_channel_prefix_matrix 2-D contiguous int32 tensor `[num_rdma_ranks, num_channels]`.
 * @param rdma_rank_prefix_sum       1-D contiguous int32 tensor `[num_rdma_ranks]`.
 * @param gbl_channel_prefix_matrix  2-D contiguous int32 tensor `[num_ranks, num_channels]`.
 * @param combined_rdma_head         2-D contiguous int32 tensor `[num_combined_tokens, num_rdma_ranks]`.
 * @param combined_nvl_head          2-D contiguous int32 tensor `[*, NUM_MAX_NVL_PEERS]`.
 * @param config                     `num_sms` must be a multiple of 3 (`num_channels = num_sms / 3`);
 *                                   the NVL chunk sizes must satisfy the avoid-dead-lock checks below.
 * @param previous_event             If set, `comm_stream` waits on it; otherwise it waits on the
 *                                   current (compute) stream.
 * @param async                      If true, an `EventHandle` on `comm_stream` is returned and the
 *                                   tensors are recorded on the stream(s); otherwise the compute
 *                                   stream waits for `comm_stream` before returning.
 * @param allocate_on_comm_stream    If true, outputs are allocated while `comm_stream` is the
 *                                   current stream; requires `previous_event` and `async`.
 * @return `(combined_x, combined_topk_weights, event)`; `combined_topk_weights` is empty when
 *         `topk_weights` is empty, `event` is empty when `async` is false.
 *
 * When built with `DISABLE_ROCSHMEM`, the function only fails a host assertion.
 */
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_combine(
    const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
    const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
    const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
    const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
    const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
    const torch::Tensor &combined_nvl_head, const Config &config,
    std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM
    const int num_channels = config.num_sms / 3;
    EP_HOST_ASSERT(config.num_sms % 3 == 0);

    // Shape and contiguous checks
    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
    EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and
                   src_meta.scalar_type() == torch::kByte);
    EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and
                   is_combined_token_in_rank.is_contiguous() and
                   is_combined_token_in_rank.scalar_type() == torch::kBool);
    EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and
                   rdma_channel_prefix_matrix.is_contiguous() and
                   rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and
                   rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and
                   gbl_channel_prefix_matrix.is_contiguous() and
                   gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and
                   combined_rdma_head.scalar_type() == torch::kInt32);
    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and
                   combined_nvl_head.scalar_type() == torch::kInt32);

    auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)),
         hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
    auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
    EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
    EP_HOST_ASSERT(src_meta.size(1) ==
                   internode::get_source_meta_bytes());
    EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
    EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and
                   rdma_channel_prefix_matrix.size(1) == num_channels);
    EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
    EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and
                   gbl_channel_prefix_matrix.size(1) == num_channels);
    EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and
                   combined_rdma_head.size(0) == num_combined_tokens and
                   combined_rdma_head.size(1) == num_rdma_ranks);
    EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and
                   combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);

    // Bias is not supported: reject it here, before the stream is switched and before
    // `cached_notify` is enqueued, so a bad call cannot leave a launched barrier behind.
    if (bias_0.has_value() or bias_1.has_value()) {
        EP_HOST_ASSERT(false and "bias is not supported in internode combine");
    }

    // Top-k checks (validation only; the output tensor is allocated after the stream switch)
    int num_topk = 0;
    if (topk_weights.has_value()) {
        EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
        EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
        EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
        num_topk = static_cast<int>(topk_weights->size(1));
    }

    // Extra check for avoid-dead-lock design
    EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
    EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <=
                   config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
    auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
    }

    // Wait previous tasks to be finished
    if (previous_event.has_value()) {
        stream_wait(comm_stream, previous_event.value());
    } else {
        stream_wait(comm_stream, compute_stream);
    }

    // Top-k output allocation (happens on the comm stream when `allocate_on_comm_stream`)
    auto combined_topk_weights = std::optional<torch::Tensor>();
    float *topk_weights_ptr = nullptr;
    float *combined_topk_weights_ptr = nullptr;
    if (topk_weights.has_value()) {
        topk_weights_ptr = topk_weights->data_ptr<float>();
        combined_topk_weights =
            torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
        combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
    }

    // Launch barrier and reset queue head and tail
    internode::cached_notify(
        hidden_int4, 0, 0, num_topk, num_ranks, num_channels, num_combined_tokens,
        combined_rdma_head.data_ptr<int>(), rdma_channel_prefix_matrix.data_ptr<int>(),
        rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(), rdma_buffer_ptr,
        config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
        config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, comm_stream,
        config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes,
        false, low_latency_mode);

    // Bias pointers: always null here because bias inputs were rejected above
    void *bias_ptrs[2] = {nullptr, nullptr};

    // Launch data combine
    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
    internode::combine(
        at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(),
        combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr<bool>(), x.data_ptr(),
        topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], combined_rdma_head.data_ptr<int>(),
        combined_nvl_head.data_ptr<int>(), src_meta.data_ptr(),
        rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(),
        gbl_channel_prefix_matrix.data_ptr<int>(), num_tokens, num_combined_tokens, hidden,
        num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens,
        config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
        config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank,
        num_ranks, comm_stream, num_channels, low_latency_mode);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        // Record every tensor touched by the kernels on the stream(s) that use it
        for (auto &t : {x, src_meta, is_combined_token_in_rank, rdma_channel_prefix_matrix,
                        rdma_rank_prefix_sum, gbl_channel_prefix_matrix, combined_x,
                        combined_rdma_head, combined_nvl_head}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto &to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::hip::setCurrentHIPStreamMasqueradingAsCUDA(compute_stream);

    // Return values
    return {combined_x, combined_topk_weights, event};
#else
    EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by "
                             "following docs/install_dependencies.md");
    return {};
#endif
}
// Low-latency mode is not implemented in this port: the call unconditionally fails
// the host assertion below. All parameters are ignored.
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts) {
EP_HOST_ASSERT(false and "not support low latency");
}
// Low-latency dispatch is not implemented in this port: the call unconditionally fails
// the host assertion below and every argument is ignored.
// NOTE(review): the trailing `return {}` presumably only satisfies the return type;
// it is reached only if EP_HOST_ASSERT (defined outside this chunk) does not throw/abort.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(false and "not support low latency");
return {};
}
// Low-latency combine is not implemented in this port: the call unconditionally fails
// the host assertion below and every argument is ignored.
// NOTE(review): the trailing `return {}` presumably only satisfies the return type;
// it is reached only if EP_HOST_ASSERT (defined outside this chunk) does not throw/abort.
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out) {
EP_HOST_ASSERT(false and "not support low latency");
return {};
}
// Low-latency mode is not implemented in this port: the call unconditionally fails
// the host assertion below. All parameters are ignored.
// NOTE(review): the trailing `return {}` (a default-constructed tensor) is reached only
// if EP_HOST_ASSERT (defined outside this chunk) does not throw/abort.
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const {
EP_HOST_ASSERT(false and "not support low latency");
return {};
}
} // namespace deep_ep
// Python extension entry point: exposes `Config`, `EventHandle`, `Buffer` and the
// free function `get_low_latency_rdma_size_hint` under the module named by
// TORCH_EXTENSION_NAME. Registration order matches the declaration order below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    using ConfigT = deep_ep::Config;
    using EventT = deep_ep::EventHandle;
    using BufferT = deep_ep::Buffer;

    m.doc() = "DeepEP: an efficient expert-parallel communication library";

    // --- Config: five integer tuning knobs, all with keyword defaults ---
    pybind11::class_<ConfigT> config_cls(m, "Config");
    config_cls.def(pybind11::init<int, int, int, int, int>(),
                   py::arg("num_sms") = 20,
                   py::arg("num_max_nvl_chunked_send_tokens") = 6,
                   py::arg("num_max_nvl_chunked_recv_tokens") = 256,
                   py::arg("num_max_rdma_chunked_send_tokens") = 6,
                   py::arg("num_max_rdma_chunked_recv_tokens") = 256);
    config_cls.def("get_nvl_buffer_size_hint", &ConfigT::get_nvl_buffer_size_hint);
    config_cls.def("get_rdma_buffer_size_hint", &ConfigT::get_rdma_buffer_size_hint);

    // --- Module-level helper ---
    m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint);

    // --- EventHandle: default-constructible, lets the current stream wait on it ---
    pybind11::class_<EventT> event_cls(m, "EventHandle");
    event_cls.def(pybind11::init<>());
    event_cls.def("current_stream_wait", &EventT::current_stream_wait);

    // --- Buffer: constructor takes (int, int, int64, int64, bool, bool, bool) ---
    pybind11::class_<BufferT> buffer_cls(m, "Buffer");
    buffer_cls.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>());

    // Introspection and lifecycle
    buffer_cls.def("is_available", &BufferT::is_available);
    buffer_cls.def("get_num_rdma_ranks", &BufferT::get_num_rdma_ranks);
    buffer_cls.def("get_rdma_rank", &BufferT::get_rdma_rank);
    buffer_cls.def("get_root_rdma_rank", &BufferT::get_root_rdma_rank);
    buffer_cls.def("get_local_device_id", &BufferT::get_local_device_id);
    buffer_cls.def("get_local_ipc_handle", &BufferT::get_local_ipc_handle);
    buffer_cls.def("get_local_nvshmem_unique_id", &BufferT::get_local_nvshmem_unique_id);
    buffer_cls.def("get_local_buffer_tensor", &BufferT::get_local_buffer_tensor);
    buffer_cls.def("get_comm_stream", &BufferT::get_comm_stream);
    buffer_cls.def("sync", &BufferT::sync);
    buffer_cls.def("destroy", &BufferT::destroy);

    // Dispatch / combine entry points
    buffer_cls.def("get_dispatch_layout", &BufferT::get_dispatch_layout);
    buffer_cls.def("intranode_dispatch", &BufferT::intranode_dispatch);
    buffer_cls.def("intranode_combine", &BufferT::intranode_combine);
    buffer_cls.def("internode_dispatch", &BufferT::internode_dispatch);
    buffer_cls.def("internode_combine", &BufferT::internode_combine);

    // Low-latency entry points
    buffer_cls.def("clean_low_latency_buffer", &BufferT::clean_low_latency_buffer);
    buffer_cls.def("low_latency_dispatch", &BufferT::low_latency_dispatch);
    buffer_cls.def("low_latency_combine", &BufferT::low_latency_combine);
    buffer_cls.def("get_next_low_latency_combine_buffer",
                   &BufferT::get_next_low_latency_combine_buffer);

    // m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
    // m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value);
}
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "./kernels/configs.cuh" #include "kernels/configs.cuh"
#include "kernels/exception.cuh" #include "kernels/exception.cuh"
#include "config.hpp" #include "config.hpp"
#include "event.hpp" #include "event.hpp"
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config_hip.hpp"
#include "event.hpp"
namespace deep_ep {
// Host-side communication buffer object exposed to Python.
// Holds the per-peer (NVLink/IPC) buffer pointers, the RDMA buffer pointer, rank
// topology, the communication stream and the host-visible receive counters, and
// declares the dispatch/combine entry points implemented in the matching source file.
struct Buffer {
// The rest of the code hard-codes 8 NVLink peers per node.
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer: total byte size, one pointer per peer, and a second pointer
// (`buffer_ptrs_gpu`) that is what gets passed to the kernels launchers.
// NOTE(review): presumably a device-resident copy of `buffer_ptrs` — confirm in the constructor/sync.
int64_t num_nvl_bytes;
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void **buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer (the implementation guards its use with DISABLE_ROCSHMEM,
// i.e. this is the rocSHMEM-backed buffer in this port)
int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr;
// Device info and communication
// NOTE(review): these have no default initializers; presumably set by the constructor — confirm.
int device_id;
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` has been called, this flag will be true
bool destroyed = false;
// Barrier signals (one pointer per peer, plus the pointer handed to kernel launchers)
int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int **barrier_signal_ptrs_gpu = nullptr;
// Workspace
void *workspace = nullptr;
// Host-side MoE info
// NOTE(review): each counter comes as a `volatile` host pointer plus a `_mapped` alias;
// presumably host-pinned memory with a device-mapped address — confirm at the allocation site.
volatile int *moe_recv_counter = nullptr;
int *moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int *moe_recv_expert_counter = nullptr;
int *moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr;
// Constructor option; NOTE(review): by its name, selects the default stream as
// `comm_stream` instead of a dedicated one — confirm in the constructor.
bool use_default_stream_as_comm_stream = false;
public:
// Constructor arguments mirror the Python binding `init<int, int, int64_t, int64_t, bool, bool, bool>`.
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
// Declared `noexcept(false)`: the destructor is allowed to propagate exceptions.
~Buffer() noexcept(false);
// Simple accessors; implementations are outside this chunk.
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
torch::Stream get_comm_stream() const;
// Exchanges peer handles / unique id across ranks.
// NOTE(review): presumably the step after which `available` becomes true — confirm.
void sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray> &root_unique_id_opt);
// Explicit teardown; see `explicitly_destroy` / `destroyed` above.
void destroy();
// Layout computation for dispatch. The trailing three parameters follow the same
// stream convention as the other entry points (see `internode_combine` below).
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Intranode dispatch (implementation outside this chunk).
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Intranode combine (implementation outside this chunk).
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0,
const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
const Config &config, std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Internode dispatch. The implementation returns, in order: recv_x, recv_x_scales,
// recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
// rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
// recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
// recv_src_meta, send_rdma_head, send_nvl_head, event.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Internode combine. Returns (combined_x, combined_topk_weights, event).
// Stream convention, as implemented: `comm_stream` waits on `previous_event` if set,
// otherwise on the current stream; with `async` an EventHandle is returned instead of
// making the current stream wait; `allocate_on_comm_stream` requires both
// `previous_event` and `async`. Bias inputs are rejected by the implementation.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(
const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
// Low-latency API: kept for interface compatibility, but every one of the four
// functions below is implemented as an unconditional "not support low latency"
// assertion failure in this port.
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt);
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
};
} // namespace deep_ep
#include "hip/hip_runtime.h"
#include "buffer.cuh" #include "buffer.cuh"
#include "configs.cuh" #include "configs.cuh"
#include "launch.cuh" #include "launch.cuh"
...@@ -8,11 +9,11 @@ ...@@ -8,11 +9,11 @@
#include <rocshmem/rocshmem.hpp> #include <rocshmem/rocshmem.hpp>
// TODO: fix unroll warnings // TODO: fix unroll warnings
#ifdef __clang__ // #ifdef __clang__
#pragma clang diagnostic push // #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed" // #pragma clang diagnostic ignored "-Wpass-failed"
#pragma clang diagnostic ignored "-Wdeprecated-volatile" // #pragma clang diagnostic ignored "-Wdeprecated-volatile"
#endif // __clang__ // #endif // __clang__
namespace deep_ep { namespace deep_ep {
...@@ -58,7 +59,7 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_ ...@@ -58,7 +59,7 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_
__host__ __device__ __forceinline__ std::pair<int, int> __host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) { int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
// Return `int32_t` offset and count to clean lijian // Return `int32_t` offset and count to clean
return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * return {(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) *
num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) /
sizeof(int), sizeof(int),
...@@ -379,29 +380,31 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { ...@@ -379,29 +380,31 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
return num_rdma_ranks < 8 ? num_rdma_ranks : 8; return num_rdma_ranks < 8 ? num_rdma_ranks : 8;
} }
template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode,
template <bool kLowLatencyMode,
int kNumRDMARanks,
bool kCachedMode,
int kNumDispatchRDMASenderWarps, int kNumDispatchRDMASenderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)> int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
__launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1) dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights, SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
SourceMeta *recv_src_meta, const int4 *x, const float *x_scales, const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head, int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
int *send_nvl_head, int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks) {
int num_ranks) {
enum class WarpRole { enum class WarpRole {
kRDMASender, kRDMASender, // 从x写入到RDMA发送缓存
kRDMASenderCoordinator, kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
kRDMAAndNVLForwarder, kRDMAAndNVLForwarder, // 从RDMA接收缓存转写到ipc nvl缓存
kForwarderCoordinator, kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers kNVLReceivers // 从nvl缓存写入到recv_x
}; };
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
...@@ -409,348 +412,334 @@ __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarp ...@@ -409,348 +412,334 @@ __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarp
const auto sm_id = static_cast<int>(blockIdx.x); const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize; const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
lane_id = get_lane_id(); const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2; channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);
const auto role_meta = [=]() -> std::pair<WarpRole, int> { const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) { if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
if (warp_id < NUM_MAX_NVL_PEERS) { if(warp_id < kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASender, -1};
} else if(warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
}
} else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
if(warp_id < NUM_MAX_NVL_PEERS) {
return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
} else { } else {
return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
} }
} else if (warp_id < kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASender, -1};
} else if (warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
} else { } else {
return {WarpRole::kNVLReceivers, return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
(warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
} }
}(); }();
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); // if(lane_id==0){
// printf("tid=%d, bid=%d, warp_role=%d\n", threadIdx.x, blockIdx.x, warp_role);
// Data checks // }
EP_DEVICE_ASSERT(num_topk <= kWarpSize);
// RDMA symmetric layout // RDMA symmetric layout
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), auto hidden_bytes = hidden_int4 * sizeof(int4);
"Invalid number of NVL peers"); auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto hidden_bytes = hidden_int4 * sizeof(int4); auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto num_bytes_per_rdma_token = auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_data = SymBuffer<int8_t>( auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks,
channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2,
kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL buffer layouts // NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
// means "Write for Senders, Read for Receivers"
void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
int rs_wr_rank = 0, ws_rr_rank = 0; int rs_wr_rank = 0, ws_rr_rank = 0;
if (warp_role == WarpRole::kRDMAAndNVLForwarder) if (warp_role == WarpRole::kRDMAAndNVLForwarder)
rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
if (warp_role == WarpRole::kNVLReceivers) if (warp_role == WarpRole::kNVLReceivers)
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers // Allocate buffers
auto nvl_channel_x = auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank) auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
.advance_also(rs_wr_buffer_ptr); auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_src_meta = auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
channel_id, num_channels, rs_wr_rank) auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
.advance_also(rs_wr_buffer_ptr); auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
auto nvl_channel_x_scales = auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_idx =
AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights =
AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start =
AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end =
AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id,
num_channels, ws_rr_rank)
.advance_also(ws_rr_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id,
num_channels, rs_wr_rank)
.advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization // RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx; __shared__ volatile int rdma_send_next_token_idx;
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
__shared__ volatile int rdma_sender_counter[1];
__shared__ volatile int rdma_forwarder_counter[1];
if (threadIdx.x == 0) {
rdma_sender_counter[0] = 0;
rdma_forwarder_counter[0] = 0;
}
__syncthreads();
// Spin barrier built on the shared counter `rdma_sender_counter[0]`.
// Lane 0 of each calling warp adds 1 to the counter, then every lane of the
// warp busy-waits until the counter reaches `kNumDispatchRDMASenderWarps + 1`.
// NOTE(review): the threshold presumably equals the number of RDMA sender
// warps plus the single sender-coordinator warp -- confirm against the warp
// role assignment.
// The counter is never reset inside this lambda, so once the threshold has
// been reached any later call falls straight through the wait loop.
auto sync_rdma_sender_smem = [&]() {
if (lane_id == 0) {
// volatile int ret = __hip_atomic_fetch_add(&rdma_sender_counter[0], 1, __ATOMIC_RELAXED,
// __HIP_MEMORY_SCOPE_WORKGROUP);
// Register this warp's arrival; `ret` (the previous counter value) is not used afterwards.
volatile int ret = atomicAdd((int*)&rdma_sender_counter[0], 1);
}
// Re-converge the warp after lane 0's increment before all lanes start polling.
syncwarp();
// Busy-wait: `rdma_sender_counter` is declared `volatile`, so it is re-read on every iteration.
while (rdma_sender_counter[0] < (kNumDispatchRDMASenderWarps + 1)) {
}
};
// Forward warp synchronization // NVL and RDMA coordinate Forward warp synchronization
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
// NOTE: Not sure that __syncthreads() is a suitable replacement
auto sync_forwarder_smem = [&]() { // Place the main logic of your kernel here, using the parameters above.
if (lane_id == 0) { if(warp_role == WarpRole::kRDMASender) {
// volatile int ret = __hip_atomic_fetch_add( /*
// &rdma_forwarder_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); 这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
volatile int ret = atomicAdd((int*)&rdma_forwarder_counter[0], 1); 它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
最后,它复制相关的数据到对称发送缓冲区。
kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
*/
// 获取任务范围
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// 清理共享内存
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
if(warp_id == 0 && lane_id == 0) {
rdma_send_next_token_idx = token_start_idx;
} }
syncwarp(); if(warp_id == 0 && lane_id < kNumRDMARanks) {
while (rdma_forwarder_counter[0] < (NUM_MAX_NVL_PEERS + 1)) { rdma_send_channel_tail[lane_id] = 0;
rdma_send_channel_next_tail[lane_id] = 0;
} }
};
if (warp_role == WarpRole::kRDMASender) { // 发送本通道中的令牌数量,通过 `-value - 1` 表示
// Get tasks EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
int token_start_idx, token_end_idx; // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
token_end_idx); for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA ranks");
(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize,
"Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks;
dst_rdma_rank += kNumDispatchRDMASenderWarps) {
if (lane_id < NUM_MAX_NVL_PEERS) { if (lane_id < NUM_MAX_NVL_PEERS) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
-(channel_id == 0
? 0
: gbl_channel_prefix_matrix
[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels +
channel_id - 1]) -
1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) { } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
-gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id -
NUM_MAX_NVL_PEERS) *
num_channels +
channel_id] -
1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) { } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
-(channel_id == 0 ? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels +
channel_id - 1]) -
1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
-rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
} }
rocshmem::rocshmem_ctx_int_put_nbi_wave( syncwarp();
if (dst_rdma_rank != rdma_rank) {
rocshmem::rocshmem_ctx_int_put_nbi_wave(
ctx, rdma_channel_meta.recv_buffer(rdma_rank), ctx, rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
} }
rocshmem::rocshmem_ctx_quiet(ctx); rocshmem::rocshmem_ctx_quiet(ctx);
sync_rdma_sender_smem(); // sync_rdma_sender_smem();
__syncthreads();
// Iterate over tokens and copy into buffer // 遍历令牌并复制到缓冲区
int64_t token_idx; int64_t token_idx;
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
: rdma_channel_data.send_buffer(lane_id); for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; // 读取RDMA秩的存在性
token_idx += kNumDispatchRDMASenderWarps) {
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0; uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks) if(lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>( is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); }
// Acquire sequential lock // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
while (lane_id == 0 and rdma_send_next_token_idx != token_idx) while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
; // 等待
}
syncwarp(); syncwarp();
// Acquire next tail // 获取下一个尾部位置
int rdma_tail_idx = -1; int rdma_tail_idx = -1;
if (is_token_in_rank_uint64 != 0) { if(is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++; rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
cached_rdma_channel_head = // 与kForwarderCoordinator相互配合,调节发送数据的频率
static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id))); while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
} }
syncwarp(); syncwarp();
// Store RDMA head for combine // 存储RDMA头部以供合并
if (lane_id < kNumRDMARanks and not kCachedMode) if(lane_id < kNumRDMARanks && !kCachedMode) {
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
}
// Update last token tail // 更新最后一个令牌尾部
if (last_rdma_tail_idx >= 0) if(last_rdma_tail_idx >= 0) {
st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
last_rdma_tail_idx + 1); }
last_rdma_tail_idx = rdma_tail_idx; last_rdma_tail_idx = rdma_tail_idx;
// Release sequential lock // 释放顺序锁
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; if(lane_id == 0) {
rdma_send_next_token_idx += 1;
}
// Broadcast tails // 广播尾部位置
SourceMeta src_meta; SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
void *dst_send_buffers[kNumTopkRDMARanks]; void* dst_send_buffers[kNumTopkRDMARanks];
#pragma unroll /*
for (int i = 0, slot_idx; i < kNumRDMARanks; ++i) 该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
if ((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) { */
slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; #pragma unroll
topk_ranks[num_topk_ranks] = i; for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
// 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
// warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
// 计算slot_idx在接收缓冲区中的位置
slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
// 存储当前RDMA秩到topk_ranks数组中
topk_ranks[num_topk_ranks] = i;
// 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
auto recv_is_token_in_rank_values = auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
reinterpret_cast<const bool *>(&recv_is_token_in_rank_uint64);
if (lane_id == num_topk_ranks) // 如果当前lane_id等于num_topk_ranks,则更新src_meta
if(lane_id == num_topk_ranks) {
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
dst_send_buffers[num_topk_ranks++] = }
reinterpret_cast<uint8_t *>(broadcast(send_buffer, i)) +
slot_idx * num_bytes_per_rdma_token; // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
// 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
} }
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
// Copy `x` into symmetric send buffer // 复制 `x` 到对称发送缓冲区
auto st_broadcast = [=](const int key, const int4 &value) { auto st_broadcast = [=](const int key, const int4& value) {
for (int j = 0; j < num_topk_ranks; ++j) #pragma unroll
st_na_global(reinterpret_cast<int4 *>(dst_send_buffers[j]) + key, value); for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
}
}; };
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
ld_nc_global, st_broadcast); #pragma unroll
for (int i = 0; i < num_topk_ranks; ++i) for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<int4 *>(dst_send_buffers[i]) + hidden_int4; dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
}
// Copy source metadata into symmetric send buffer
if (lane_id < num_topk_ranks)
st_na_global(reinterpret_cast<SourceMeta *>(dst_send_buffers[lane_id]), src_meta);
for (int i = 0; i < num_topk_ranks; ++i) // 复制源元数据到对称发送缓冲区
dst_send_buffers[i] = reinterpret_cast<SourceMeta *>(dst_send_buffers[i]) + 1; if(lane_id < num_topk_ranks) {
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
}
// Copy `x_scales` into symmetric send buffer // 复制 `x_scales` 到对称发送缓冲区
for (int i = lane_id; i < num_scales; i += kWarpSize) { #pragma unroll
for(int i = lane_id; i < num_scales; i += kWarpSize) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i); auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
for (int j = 0; j < num_topk_ranks; ++j)
st_na_global(reinterpret_cast<float *>(dst_send_buffers[j]) + i, value); // auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
// auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
}
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
} }
for (int i = 0; i < num_topk_ranks; ++i)
dst_send_buffers[i] = reinterpret_cast<float *>(dst_send_buffers[i]) + num_scales;
// Copy `topk_idx` and `topk_weights` into symmetric send buffer // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
for (int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) { #pragma unroll
for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
auto rank_idx = i / num_topk, copy_idx = i % num_topk; auto rank_idx = i / num_topk, copy_idx = i % num_topk;
auto idx_value = auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
st_na_global(reinterpret_cast<int *>(dst_send_buffers[rank_idx]) + copy_idx, st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
idx_value); st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
st_na_global(reinterpret_cast<float *>(dst_send_buffers[rank_idx]) + num_topk +
copy_idx,
weight_value);
} }
} }
// Epilogue // 结尾部分
// Acquire sequential lock // 获取顺序锁
while (lane_id == 0 and rdma_send_next_token_idx != token_idx) while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
; // 等待
}
syncwarp(); syncwarp();
// Update last token tail // 更新最后一个令牌尾部
if (last_rdma_tail_idx >= 0) if(last_rdma_tail_idx >= 0) {
st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
last_rdma_tail_idx + 1); }
// Release sequential lock // 释放顺序锁
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; if(lane_id == 0) {
} else if (warp_role == WarpRole::kRDMASenderCoordinator) { rdma_send_next_token_idx += 1;
// NOTES: in case of splitting the issued put at the end of the buffer }
EP_DEVICE_ASSERT( } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); /*
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
*/
if(warp_id > kNumDispatchRDMASenderWarps) {
return;
}
// 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// Synchronize shared memory // 同步共享内存,确保所有线程在继续之前都达到了这一点
sync_rdma_sender_smem(); // sync_rdma_sender_smem();
__syncthreads();
// Get number of tokens to send for each RDMA rank // 计算当前通道需要发送的令牌数
int num_tokens_to_send = 0; int num_tokens_to_send = 0;
if (lane_id < kNumRDMARanks) { if(lane_id < kNumRDMARanks) {
num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
if (channel_id > 0) if(channel_id > 0)
num_tokens_to_send -= num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
} }
// Iterate all RDMA ranks // 记录上次发出的尾部位置
int last_issued_tail = 0; int last_issued_tail = 0;
while (__any_sync(kFullWarpMask, num_tokens_to_send > 0)) { // 当有任何RDMA秩需要发送令牌时,继续循环
for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) { while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
int dst_rdma_rank = (i + channel_id) % kNumRDMARanks; for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
// 计算目标RDMA秩
int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
// 获取同步后的需要发送的令牌数
synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank); synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
if (synced_num_tokens_to_send == 0)
continue;
// Read progress if(synced_num_tokens_to_send == 0)
continue; // 如果没有令牌需要发送,则跳过
// 读取进度
auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank); auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
auto processed_tail = auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
ld_acquire_cta(const_cast<const int *>(rdma_send_channel_tail + dst_rdma_rank)); auto num_tokens_processed = processed_tail - synced_last_issued_tail;
auto num_tokens_processed = processed_tail - synced_last_issued_tail;
if (num_tokens_processed != synced_num_tokens_to_send and // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
num_tokens_processed < num_max_rdma_chunked_send_tokens) if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue; continue;
// Issue RDMA send // 计算本次需要发出的令牌数
auto num_tokens_to_issue = auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
min(num_tokens_processed, num_max_rdma_chunked_send_tokens); EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and
num_tokens_to_issue <= synced_num_tokens_to_send); // 发出RDMA发送请求
if (dst_rdma_rank != rdma_rank) { if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
num_max_rdma_chunked_recv_tokens);
rocshmem::rocshmem_ctx_schar_put_nbi_wave( rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx, ctx,
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
...@@ -761,370 +750,366 @@ __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarp ...@@ -761,370 +750,366 @@ __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarp
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx); rocshmem::rocshmem_ctx_quiet(ctx);
} else { } else {
// Lighter fence for local RDMA rank // 对于本地RDMA秩,使用较轻的内存屏障
memory_fence(); memory_fence();
} }
// Update tails // 更新尾部位置
syncwarp(); syncwarp();
if (lane_id == dst_rdma_rank) { if(lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
rocshmem::rocshmem_ctx_ulong_atomic_add( rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
} }
} }
} } // while(__any(num_tokens_to_send > 0))
} else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
// RDMA consumers and NVL producers /*
const auto dst_nvl_rank = target_rank; 这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; 它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); 接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); 然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
最后,它同步头部和尾部索引,并标记通道为退役状态。
// Wait counters to arrive */
// RDMA消费者和NVL生产者
const auto dst_nvl_rank = target_rank; // 目标NVL秩
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; // 目标秩
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); // 目标秩专家开始
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束
// 等待计数器到达
int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
auto start_time = wall_clock64(); auto start_time = wall_clock64();
if (lane_id < kNumRDMARanks) { if(lane_id < kNumRDMARanks) {
while (true) { while(true) {
auto meta_0 = // 对应于kRDMASender中的数据写入
ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); // 是nvl节点的起始地址
auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
NUM_MAX_NVL_PEERS + dst_nvl_rank); auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); // 本rdma节点的起始地址
auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); // 本节点的结束地址
NUM_MAX_NVL_PEERS * 2); if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + // 通知NVL秩
NUM_MAX_NVL_PEERS * 2 + 1);
if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
// Notify NVL ranks
int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);
end_sum >= start_sum);
st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
-start_sum - 1);
st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);
// Save RDMA channel received token count // 保存从RDMA通道接收的令牌计数
src_rdma_channel_prefix = -meta_2 - 1; src_rdma_channel_prefix = -meta_2 - 1;
auto src_rdma_channel_prefix_1 = -meta_3 - 1; auto src_rdma_channel_prefix_1 = -meta_3 - 1;
num_tokens_to_recv_from_rdma = num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
src_rdma_channel_prefix_1 - src_rdma_channel_prefix; if(!kCachedMode)
if (not kCachedMode) recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] =
src_rdma_channel_prefix_1; src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
src_rdma_channel_prefix +=
lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
break; break;
} }
// Timeout check // 超时检查
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES) { channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, "
"nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1,
meta_2, meta_3);
trap(); trap();
} }
} }
} }
syncwarp(); syncwarp();
// Shift cached head
// 移动缓存的头部
send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;
// Wait shared memory to be cleaned // 等待共享内存被清理
sync_forwarder_smem(); // sync_forwarder_smem();
__syncthreads();
// Forward tokens from RDMA buffer // 开始准备处理接受数据,直到所有的数据接受完成。
// NOTES: always start from the local rank // 转发从RDMA缓冲区的令牌
int src_rdma_rank = sm_id % kNumRDMARanks; // 注意:总是从本地秩开始
int src_rdma_rank = sm_id % kNumRDMARanks;
int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
while (__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) { while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
// Check destination queue emptiness, or wait a buffer to be released // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
start_time = wall_clock64(); start_time = wall_clock64();
while (lane_id == 0) {
// 用于给kNVLReceivers进行互动,控制数据的传输速度
while(lane_id == 0) {
int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
if (num_max_nvl_chunked_recv_tokens - num_used_slots >= if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
num_max_nvl_chunked_send_tokens)
break; break;
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
// Timeout check // 超时检查
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES) { channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank,
ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap(); trap();
} }
} }
syncwarp(); syncwarp();
// Find next source RDMA rank (round-robin) // 找到下一个源RDMA秩(轮询)
start_time = wall_clock64(); start_time = wall_clock64();
while (true) { while(true) {
src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
if (shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
if (lane_id == src_rdma_rank and if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
cached_rdma_channel_head == cached_rdma_channel_tail) cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
cached_rdma_channel_tail = static_cast<int>(
ld_relaxed_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
if (shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head,
src_rdma_rank))
break; break;
}
} }
// Timeout check // 超时检查
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, "
"nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: "
"%d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id,
cached_rdma_channel_head, cached_rdma_channel_tail,
num_tokens_to_recv_from_rdma);
trap(); trap();
} }
} }
auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank); auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank); auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
// Iterate over every token from the RDMA buffer // 遍历RDMA缓冲区中的每一个令牌
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
void *shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
rdma_slot_idx * num_bytes_per_rdma_token; void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta *>( auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
reinterpret_cast<int8_t *>(shifted) + hidden_bytes)); if(lane_id == src_rdma_rank) {
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; num_tokens_to_recv_from_rdma -= 1;
}
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if (lane_id == src_rdma_rank) { if(lane_id == src_rdma_rank) {
auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1; auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
rdma_nvl_token_idx += is_in_dst_nvl_rank; rdma_nvl_token_idx += is_in_dst_nvl_rank;
if (not kCachedMode) if(!kCachedMode)
send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
} }
if (not is_in_dst_nvl_rank)
if(!is_in_dst_nvl_rank)
continue; continue;
// Get an empty slot // 获取一个空闲槽位
int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
// Copy data // 复制数据
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4 *>(shifted), ld_nc_global, st_na_global); reinterpret_cast<int4*>(shifted),
shifted = reinterpret_cast<int4 *>(shifted) + hidden_int4; ld_nc_global, st_na_global);
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
// Copy source meta // 复制源元数据
if (lane_id == 0) if(lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta *>(shifted) + 1; shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
// Copy `x_scales` // 复制 `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales, UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float *>(shifted), ld_nc_global, st_na_global); reinterpret_cast<float*>(shifted),
shifted = reinterpret_cast<float *>(shifted) + num_scales; ld_nc_global, st_na_global);
shifted = reinterpret_cast<float*>(shifted) + num_scales;
// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only several lanes are // 复制 `topk_idx` 和 `topk_weights`
// shifted if(lane_id < num_topk) {
if (lane_id < num_topk) { // 读取
// Read auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
auto idx_value = ld_nc_global(reinterpret_cast<int *>(shifted) + lane_id); shifted = reinterpret_cast<int*>(shifted) + num_topk;
shifted = reinterpret_cast<int *>(shifted) + num_topk; auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
auto weight_value = ld_nc_global(reinterpret_cast<float *>(shifted) + lane_id);
// 转换和写入
// Transform and write idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
idx_value = st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
(idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end)
? idx_value - dst_rank_expert_begin
: -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id,
idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f; weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
lane_id,
weight_value);
} }
// In case of insufficient NVL buffers, early stopping // 在NVL缓冲区不足的情况下,提前停止
if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
src_rdma_tail = i + 1; src_rdma_tail = i + 1;
} }
// Sync head index // 同步头部索引
if (lane_id == src_rdma_rank) if(lane_id == src_rdma_rank)
forward_channel_head[dst_nvl_rank][src_rdma_rank] = forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
(cached_rdma_channel_head = src_rdma_tail);
// Move tail index // 移动尾部索引,与kNVLReceivers互相通信使用
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
st_relaxed_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
}
} }
// Retired // Retired
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
forward_channel_retired[dst_nvl_rank] = true; forward_channel_retired[dst_nvl_rank] = true;
} else if (warp_role == WarpRole::kForwarderCoordinator) { }
} else if(warp_role == WarpRole::kForwarderCoordinator) {
/*
这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
然后,它清理共享内存,并初始化转发通道的头部和退役状态。
接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
*/
// Extra warps for forwarder coordinator should exit directly // Extra warps for forwarder coordinator should exit directly
if (target_rank > 0) if (warp_id > NUM_MAX_NVL_PEERS)
return; return;
// Forward warp coordinator // 转发warp协调器
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
// 清理共享内存
// Clean shared memory EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers"); #pragma unroll
for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize) for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
if (lane_id < NUM_MAX_NVL_PEERS) if(lane_id < NUM_MAX_NVL_PEERS)
forward_channel_retired[lane_id] = false; forward_channel_retired[lane_id] = false;
sync_forwarder_smem(); // sync_forwarder_smem();
__syncthreads();
int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0; int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
while (true) {
// Find minimum head while(true) {
// 找到最小的头部
int min_head = std::numeric_limits<int>::max(); int min_head = std::numeric_limits<int>::max();
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
if (not forward_channel_retired[i]) if(!forward_channel_retired[i])
min_head = min(min_head, forward_channel_head[i][target_rdma]); min_head = min(min_head, forward_channel_head[i][target_rdma]);
if (__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max()))
if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
break; break;
}
// Update remote head // 更新远程头部
if (min_head != std::numeric_limits<int>::max() and if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
min_head >= last_head + num_max_rdma_chunked_send_tokens and
lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add( rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head, ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head; last_head = min_head;
} }
// Nanosleep and let other warps work // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64); __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
} }
} else { } else if(warp_role == WarpRole::kNVLReceivers) {
// NVL consumers if(warp_id >= NUM_MAX_NVL_PEERS) {
// Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) return;
}
// Place the main logic of your kernel here, using the parameters above.
// NVL消费者
// 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
int src_nvl_rank = target_rank, total_offset = 0; int src_nvl_rank = target_rank, total_offset = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers"); EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
// Receive channel offsets // 接收通道偏移
int start_offset = 0, end_offset = 0, num_tokens_to_recv; int start_offset = 0, end_offset = 0, num_tokens_to_recv;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (lane_id < kNumRDMARanks) {
while(lane_id < kNumRDMARanks) {
start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
if (start_offset < 0 and end_offset < 0) { if(start_offset < 0 && end_offset < 0) {
start_offset = -start_offset - 1, end_offset = -end_offset - 1; start_offset = -start_offset - 1, end_offset = -end_offset - 1;
total_offset += start_offset; total_offset += start_offset;
break; break;
} }
// 超时检查
// Timeout check if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
long long int elapsed_time = printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
wall_clock64() > start_time ? wall_clock64() - start_time : 0; channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src "
"RDMA: %d, src nvl: %d, start: %d, end: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset,
end_offset);
trap(); trap();
} }
} }
num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
// Save for combine usage // 保存以供合并使用
if (lane_id < kNumRDMARanks and not kCachedMode) if(lane_id < kNumRDMARanks && !kCachedMode)
recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
num_channels +
channel_id] = total_offset;
syncwarp(); syncwarp();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) { while(num_tokens_to_recv > 0) {
// Check channel status by lane 0 // 通过通道0检查通道状态
start_time = wall_clock64(); start_time = wall_clock64();
while (lane_id == 0) { while(lane_id == 0) {
// Ready to copy // 准备复制
if (cached_channel_head_idx != cached_channel_tail_idx) if(cached_channel_head_idx != cached_channel_tail_idx)
break; break;
cached_channel_tail_idx = ld_relaxed_sys_global(nvl_channel_tail.buffer()); cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
// 超时检查
// Timeout check if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
long long int elapsed_time = printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
wall_clock64() > start_time ? wall_clock64() - start_time : 0; channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
if (elapsed_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, "
"src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx,
cached_channel_tail_idx);
trap(); trap();
} }
} }
// Sync queue tail // 同步队列尾部
cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0); cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);
// Copy data // 复制数据
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
++chunk_idx, --num_tokens_to_recv) { int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
int token_idx_in_buffer = auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
(cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank);
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
// Copy data // 复制数据
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, recv_x + recv_token_idx * hidden_int4, UNROLLED_WARP_COPY(5,
nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, lane_id,
ld_nc_global, st_na_global); hidden_int4,
recv_x + recv_token_idx * hidden_int4,
// Copy source meta nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
if (lane_id == 0 and not kCachedMode) ld_nc_global,
st_na_global);
// 复制源元数据
if(lane_id == 0 && !kCachedMode)
st_na_global(recv_src_meta + recv_token_idx, meta); st_na_global(recv_src_meta + recv_token_idx, meta);
// Copy scales // 复制比例
UNROLLED_WARP_COPY(1, lane_id, num_scales, UNROLLED_WARP_COPY(1,
recv_x_scales + recv_token_idx * num_scales, lane_id,
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, num_scales,
ld_nc_global, st_na_global); recv_x_scales + recv_token_idx * num_scales,
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
// Copy `topk_idx` and `topk_weights` ld_nc_global,
if (lane_id < num_topk) { st_na_global);
// 复制 `topk_idx` 和 `topk_weights`
if(lane_id < num_topk) {
auto recv_idx = recv_token_idx * num_topk + lane_id; auto recv_idx = recv_token_idx * num_topk + lane_id;
auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
st_na_global(recv_topk_idx + recv_idx, st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
static_cast<int64_t>( st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
st_na_global(recv_topk_weights + recv_idx,
ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
} }
} }
// Move queue // 移动队列
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
} }
} // while(num_tokens_to_recv > 0)
} }
rocshmem::rocshmem_wg_ctx_destroy(&ctx); rocshmem::rocshmem_wg_ctx_destroy(&ctx);
} }
...@@ -1143,6 +1128,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float ...@@ -1143,6 +1128,8 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels, int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
bool low_latency_mode) { bool low_latency_mode) {
constexpr int kNumDispatchRDMASenderWarps = 7; constexpr int kNumDispatchRDMASenderWarps = 7;
// Make sure never OOB
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \ { \
...@@ -1170,14 +1157,14 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float ...@@ -1170,14 +1157,14 @@ void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float
EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
SETUP_LAUNCH_CONFIG(num_channels * 2, SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
(kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream); (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <bool kLowLatencyMode> template <bool kLowLatencyMode>
__global__ void __global__ void __launch_bounds__(1024, 1)
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset, cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens, const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
int num_channels, const int *rdma_channel_prefix_matrix, int num_channels, const int *rdma_channel_prefix_matrix,
...@@ -1298,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1298,7 +1285,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) { bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = std::max(128, kWarpSize * num_channels); const int num_threads = ::max(128, kWarpSize * num_channels);
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta // Get clean meta
...@@ -1327,56 +1314,58 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to ...@@ -1327,56 +1314,58 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
cpu_rdma_team); cpu_rdma_team);
} }
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, int kWidth, typename ReceiveFn, template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
typename ReceiveTWFn> __device__ int combine_token(bool is_token_in_rank, int head_idx,
__device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, int hidden_int4, int lane_id, int hidden_int4, int num_topk,
int num_topk, int4 *combined_row, float *combined_topk_weights, int4* combined_row, float* combined_topk_weights,
int num_max_recv_tokens, const ReceiveFn &recv_fn, int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
const ReceiveTWFn &recv_tw_fn) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
// Broadcast current heads // Broadcast current heads
// Lane `i` holds the head of rank `i` and `is_token_in_rank` // Lane `i` holds the head of rank `i` and `is_token_in_rank`
EP_STATIC_ASSERT(kMaxNumRanks <= kWidth, "Too many ranks"); EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
#pragma unroll #pragma unroll
for (int i = 0; i < kNumRanks; ++i) for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
if (shfl_sync(is_token_in_rank, i, kWidth)) { slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
slot_indices[num_topk_ranks] = shfl_sync(head_idx, i, kWidth) % num_max_recv_tokens; topk_ranks[num_topk_ranks ++] = i;
topk_ranks[num_topk_ranks++] = i; }
}
EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);
// Reduce data // Reduce data
for (int i = lane_id; i < hidden_int4; i += kWidth) { #pragma unroll
for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
// Read buffers
// TODO: maybe too many registers here
int4 recv_value_int4[kMaxNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0}; float values[kDtypePerInt4] = {0};
#pragma unroll
// Temporary buffer for (int j = 0; j < num_topk_ranks; ++ j) {
int4 temp; auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++j) { for (int k = 0; k < kDtypePerInt4; ++ k)
temp = recv_fn(topk_ranks[j], slot_indices[j], i); values[k] += static_cast<float>(recv_value_dtypes[k]);
const dtype_t *d = reinterpret_cast<const dtype_t *>(&temp);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++k)
values[k] += static_cast<float>(d[k]);
} }
int4 out_int4; // Cast back to `dtype_t` and write
dtype_t *out_dtypes = reinterpret_cast<dtype_t *>(&out_int4); int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll #pragma unroll
for (int j = 0; j < kDtypePerInt4; ++j) for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]); out_dtypes[j] = static_cast<dtype_t>(values[j]);
st_na_global(combined_row + i, out_int4); st_na_global(combined_row + i, out_int4);
} }
// Reduce `topk_weights` // Reduce `topk_weights`
if (lane_id < num_topk) { if (lane_id < num_topk) {
float value = 0; float value = 0;
for (int i = 0; i < num_topk_ranks; ++i) #pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
st_na_global(combined_topk_weights + lane_id, value); st_na_global(combined_topk_weights + lane_id, value);
} }
...@@ -1385,16 +1374,16 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, i ...@@ -1385,16 +1374,16 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx, int lane_id, i
return topk_ranks[0]; return topk_ranks[0];
} }
template <bool kLowLatencyMode, int kNumRDMARanks, typename dtype_t, int kNumCombineForwarderWarps, template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks), int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
? kNumCombineForwarderWarps / kNumRDMARanks
: 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> int kNumRDMAReceivers = kNumForwarders>
__global__ void __global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1)
__launch_bounds__((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWarpSize, 1) combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1, const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta, const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum, const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
...@@ -1403,445 +1392,342 @@ __launch_bounds__((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWa ...@@ -1403,445 +1392,342 @@ __launch_bounds__((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWa
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks) { int num_ranks) {
enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator }; enum class WarpRole {
kNVLSender,
const auto sm_id = static_cast<int>(blockIdx.x); kNVLAndRDMAForwarder,
const auto num_threads = static_cast<int>(blockDim.x); kRDMAReceiver,
const auto num_warps = num_threads / kEmulatedWarpSize - 1; kRDMACoordinator,
auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id() % kEmulatedWarpSize; kNVLCoordinator
const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2; };
const bool is_rdma_receiver_sm = sm_id % 2 == 1;
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); rocshmem::rocshmem_wg_ctx_create(0, &ctx);
EP_DEVICE_ASSERT(num_topk <= kEmulatedWarpSize); const auto sm_id = static_cast<int>(blockIdx.x);
EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs // NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto role_meta = [=]() -> std::pair<WarpRole, int> {
auto warp_id = thread_id / kEmulatedWarpSize; const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (not is_rdma_receiver_sm) { if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
if (warp_id < NUM_MAX_NVL_PEERS) { return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
auto shuffled_warp_id = warp_id; } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; if(warp_id < kNumForwarders) {
return {WarpRole::kNVLSender, shuffled_warp_id}; return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
} else { } else {
return {WarpRole::kCoordinator, 0}; return {WarpRole::kRDMACoordinator, 0};
} }
} else { } else {
if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { if(warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id}; return {WarpRole::kRDMAReceiver, warp_id};
} else { } else {
return {WarpRole::kCoordinator, 0}; return {WarpRole::kNVLCoordinator, 0};
} }
} }
}(); }();
auto warp_role = role_meta.first; auto warp_role = role_meta.first;
auto warp_id = role_meta.second; auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// This approach is designed to sync multiple warps in a loop // This approach is designed to sync multiple warps in a loop
constexpr int num_sync_large_iteration = 64; constexpr int num_sync_large_iteration = 64;
__shared__ volatile int rdma_receiver_counter[1]; constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
__shared__ volatile int rdma_forwarder_counter[1]; __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];
__shared__ volatile uint8_t for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
sync_large_warp_counters[2 * kNumRDMARanks * num_sync_large_iteration];
if (threadIdx.x == 0) {
rdma_receiver_counter[0] = 0;
rdma_forwarder_counter[0] = 0;
}
for (int i = thread_id; i < 2 * kNumRDMARanks * num_sync_large_iteration; i += num_threads) {
sync_large_warp_counters[i] = 0; sync_large_warp_counters[i] = 0;
} }
__syncthreads(); __syncthreads();
if (warp_role == WarpRole::kNVLSender) { if (warp_role == WarpRole::kNVLSender) {
// NVL producers if(warp_id >= NUM_MAX_NVL_PEERS) {
const auto dst_nvl_rank = warp_id; return;
}
const auto dst_nvl_rank = target_rank;
// NVL layouts // NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank) auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
.advance_also(local_buffer_ptr); auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_src_meta = auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens,
NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights =
AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS,
channel_id, num_channels, dst_nvl_rank)
.advance_also(dst_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS,
channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);
// Get tasks for each RDMA lane // Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0; int token_start_idx = 0, token_end_idx = 0;
if (lane_id < kNumRDMARanks) { if(lane_id < kNumRDMARanks) {
int prefix_idx = int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
(lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
? num_tokens
: gbl_channel_prefix_matrix[prefix_idx + 1];
} }
syncwarp(); syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kEmulatedWarpSize, EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
"Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks // Iterate over all tokens and send by chunks
while (true) { while(true) {
// Exit if possible // Exit if possible
if (__all_sync(kFullWarpMask, token_start_idx >= token_end_idx)) if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
break; break;
// Decide next RDMA buffer to send // Decide next RDMA buffer to send
bool is_lane_ready = false; bool is_lane_ready = false;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (true) {
while(true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
num_max_nvl_chunked_send_tokens;
if (__any_sync(kFirstHalfMask, is_lane_ready)) if(__any_sync(kFullWarpMask, is_lane_ready))
break;
if (__any_sync(kSecondHalfMask, is_lane_ready))
break; break;
// Retry // Retry
if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check // Timeout check
long long int elapsed_time = if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst " channel_id,
"NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", rdma_rank,
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, nvl_rank,
ld_volatile_global(nvl_channel_head.buffer() + lane_id), dst_nvl_rank,
cached_channel_tail_idx, token_start_idx, token_end_idx); lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx,
token_start_idx,
token_end_idx);
trap(); trap();
} }
__builtin_amdgcn_s_sleep(1);
} }
// Sync token start index and count // Sync token start index and count
for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) { for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
if (shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
current_rdma_idx, kEmulatedWarpSize))
continue; continue;
// Sync token start index // Sync token start index
auto token_idx = static_cast<int64_t>( auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
shfl_sync(token_start_idx, current_rdma_idx, kEmulatedWarpSize)); int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
int num_tokens_in_chunk =
shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx),
current_rdma_idx, kEmulatedWarpSize);
// Send by chunk // Send by chunk
for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) { for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
// Get an empty slot // Get an empty slot
int dst_slot_idx = 0; int dst_slot_idx = 0;
if (lane_id == current_rdma_idx) { if(lane_id == current_rdma_idx) {
dst_slot_idx = dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
(cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma; dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma +
dst_slot_idx;
} }
dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx, kEmulatedWarpSize); dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
// Copy data // Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4; auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY_EMULATED(5, lane_id, hidden_int4, shifted_x_buffers, UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
shifted_x, ld_nc_global, st_na_global);
// Copy source meta // Copy source meta
if (lane_id == num_topk) if(lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
ld_nc_global(src_meta + token_idx));
// Copy `topk_weights` // Copy `topk_weights`
if (lane_id < num_topk) if(lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
} }
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0; lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
} }
// Move queue tail // Move queue tail
syncwarp(); syncwarp();
if (lane_id < kNumRDMARanks and is_lane_ready) if(lane_id < kNumRDMARanks and is_lane_ready) {
st_relaxed_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
} }
} else { } else {
if(warp_id > kNumForwarders) {
return;
}
// Combiners and coordinators // Combiners and coordinators
// RDMA symmetric layout // RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4); auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>( auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail =
SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL layouts // NVL layouts
void *local_nvl_buffer = buffer_ptrs[nvl_rank]; void* local_nvl_buffer = buffer_ptrs[nvl_rank];
void *nvl_buffers[NUM_MAX_NVL_PEERS]; void* nvl_buffers[NUM_MAX_NVL_PEERS];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
nvl_buffers[i] = buffer_ptrs[i]; nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
NUM_MAX_NVL_PEERS, channel_id, num_channels) auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers); auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_src_meta = auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens,
NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights =
AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk,
NUM_MAX_NVL_PEERS, channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head =
AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS,
channel_id, num_channels, nvl_rank)
.advance_also(local_nvl_buffer);
auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS,
channel_id, num_channels)
.advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
// Combiner warp synchronization // Combiner warp synchronization
__shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forwarder_retired[kNumForwarders]; __shared__ volatile bool forwarder_retired[kNumForwarders];
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
auto sync_forwarder_smem = [&]() {
if (lane_id == 0) {
// volatile int ret = __hip_atomic_fetch_add(
// &rdma_forwarder_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&rdma_forwarder_counter[0], 1);
}
syncwarp();
while (rdma_forwarder_counter[0] < (kNumForwarders + 1)) {
}
};
auto sync_rdma_receiver_smem = [&]() {
if (lane_id == 0) {
// volatile int ret = __hip_atomic_fetch_add(
// &rdma_receiver_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&rdma_receiver_counter[0], 1);
}
syncwarp();
while (rdma_receiver_counter[0] < (kNumRDMAReceivers + 1)) {
}
};
if (warp_role == WarpRole::kNVLAndRDMAForwarder) { if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks // Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks // NOTES: this part is using "large warps" for each RDMA ranks
const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; const auto sub_warp_id = target_rank % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
? rdma_channel_data.recv_buffer(dst_rdma_rank) // auto sync_large_warp = [=]() {
: rdma_channel_data.send_buffer(dst_rdma_rank); // if(kNumWarpsPerForwarder == 1) {
auto sync_large_warp = [=](const int iter, const int mode) { // syncwarp();
// } else {
// // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// // __syncthreads();
// syncwarp();
// }
// };
auto sync_large_warp = [=](const int iter, const int mode) {
if (kNumWarpsPerForwarder == 1) { if (kNumWarpsPerForwarder == 1) {
syncwarp(); syncwarp();
} else { } else {
// LDS index to store for sync
// LDS index to store for sync int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
int lds_dst_rdma_rank = dst_rdma_rank + //reset index in the LDS to avoid race condition due to warp scheduling
(iter % num_sync_large_iteration) * kNumRDMARanks + int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
mode * kNumRDMARanks * num_sync_large_iteration; // // if (lane_id==0)
// reset index in the LDS to avoid race condition due to warp scheduling // // printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
int reset_idx = auto start_time = wall_clock64();
dst_rdma_rank + if (lane_id == 0){
((iter + num_sync_large_iteration / 2) % num_sync_large_iteration) * volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
kNumRDMARanks + }
mode * kNumRDMARanks * num_sync_large_iteration; syncwarp();
auto start_time = clock64(); //The while(...) loop polls the counter until all warps have arrived
if (lane_id == 0) { if (lane_id == 0){
volatile int ret = while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
// __hip_atomic_fetch_add(&sync_large_warp_counters[lds_dst_rdma_rank], 1, if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1); trap();
} }
syncwarp();
// The while(...) loop polls the counter until all warps have arrived
if (lane_id == 0) {
while (sync_large_warp_counters[lds_dst_rdma_rank] <
(kNumWarpsPerForwarder)) {
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine sync timeout. current "
"num_sync_large_iteration %d. double it.\n",
num_sync_large_iteration);
trap();
} }
} }
} syncwarp();
syncwarp(); if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
if (lane_id == 0 && sync_large_warp_counters[reset_idx] = 0;
sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder) { }
sync_large_warp_counters[reset_idx] = 0; syncwarp();
}
syncwarp();
} }
}; };
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
"Barriers are not enough");
// Advance to the corresponding NVL buffer // Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank); nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank); nvl_channel_tail.advance(dst_rdma_rank);
// Clean shared memory and sync // Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
"Invalid number of NVL peers"); lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
lane_id == 0 ? (forwarder_retired[warp_id] = false) : false; // sync_forwarder_smem();
sync_forwarder_smem(); __syncthreads();
// Get count and cached head // Get count and cached head
int cached_nvl_channel_tail_idx = 0; int cached_nvl_channel_tail_idx = 0;
int num_tokens_to_combine = int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
int num_tokens_prefix =
channel_id == 0
? 0
: rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
num_tokens_to_combine -= num_tokens_prefix; num_tokens_to_combine -= num_tokens_prefix;
num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks // Iterate over all tokens and combine by chunks
for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check destination queue emptiness, or wait a buffer to be released // Check destination queue emptiness, or wait a buffer to be released
auto token_end_idx = auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx; auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (sub_warp_id == 0 and lane_id == 0) { while(sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// num_chunked_tokens` Here, `token_start_idx` is the actual tail // Here, `token_start_idx` is the actual tail
int num_used_slots = int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
token_start_idx -
ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break; break;
// Timeout check // Timeout check
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES) { channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: "
"%d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank,
ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)),
token_start_idx, num_chunked_tokens);
trap(); trap();
} }
} }
// sync_large_warp();
sync_large_warp(token_start_idx, 0); sync_large_warp(token_start_idx, 0);
// Combine and write to the RDMA buffer // Combine and write to the RDMA buffer
for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
token_idx += kNumWarpsPerForwarder) {
// Read expected head // Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
"Invalid number of RDMA peers");
int expected_head = -1; int expected_head = -1;
if (lane_id < NUM_MAX_NVL_PEERS) { if(lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
token_idx * NUM_MAX_NVL_PEERS + lane_id);
expected_head < 0
? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1)
: (forwarder_nvl_head[warp_id][lane_id] = expected_head);
}
// Wait lanes to be ready // Wait lanes to be ready
start_time = wall_clock64(); start_time = wall_clock64();
while (cached_nvl_channel_tail_idx <= expected_head) { while(cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
ld_relaxed_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check // Timeout check
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, "
"RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, "
"waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank,
cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine,
sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap(); trap();
} }
__builtin_amdgcn_s_sleep(1);
} }
// Combine current token // Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void *shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
int hidden_int4_idx) -> int4 { auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
slot_idx * hidden_int4 + hidden_int4_idx); expected_head, lane_id,
}; hidden_int4, num_topk,
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { reinterpret_cast<int4*>(shifted),
return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
slot_idx * num_topk + topk_idx); num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
};
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS, kEmulatedWarpSize>(
expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk,
reinterpret_cast<int4 *>(shifted),
reinterpret_cast<float *>(reinterpret_cast<int8_t *>(shifted) +
hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head // Update head
if (lane_id < NUM_MAX_NVL_PEERS) if(lane_id < NUM_MAX_NVL_PEERS) {
expected_head < 0 expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
: (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); }
} }
// sync_large_warp();
sync_large_warp(token_start_idx, 1); sync_large_warp(token_start_idx, 1);
// Issue RDMA send // Issue RDMA send
// TODO: Switch back to put_nbi_wave function if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if (sub_warp_id == kNumWarpsPerForwarder - 1 && lane_id == 0) { if(dst_rdma_rank != rdma_rank) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi( rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx, ctx,
rdma_channel_data.recv_buffer(rdma_rank) + rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token, rdma_slot_idx * num_bytes_per_rdma_token,
...@@ -1857,146 +1743,142 @@ __launch_bounds__((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWa ...@@ -1857,146 +1743,142 @@ __launch_bounds__((NUM_MAX_NVL_PEERS + kNumForwarders) * kEmulatedWarpSize + kWa
// Write new RDMA tail // Write new RDMA tail
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add( rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
} }
} }
// Retired // Retired
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
forwarder_retired[warp_id] = true; forwarder_retired[target_rank] = true;
} else if (warp_role == WarpRole::kRDMAReceiver) { }
} else if (warp_role == WarpRole::kRDMACoordinator) {
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads();
constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
{
// Find minimum head for NVL ranks
#pragma unroll
for(int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int j = 0; j < num_warps_per_rdma_rank; ++j)
if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
} else if(warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor // Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync // Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize); EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
sync_rdma_receiver_smem(); // sync_rdma_receiver_smem();
__syncthreads();
// The same tokens as the dispatch process // The same tokens as the dispatch process
int token_start_idx, token_end_idx; int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
token_end_idx);
// Iterate over all tokens and combine // Iterate over all tokens and combine
int cached_channel_tail_idx = 0; int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
token_idx += kNumRDMAReceivers) {
// Read expected head // Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
"Invalid number of RDMA peers");
int expected_head = -1; int expected_head = -1;
if (lane_id < kNumRDMARanks) { if(lane_id < kNumRDMARanks) {
expected_head = expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
(expected_head < 0) : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
} }
// Wait lanes to be ready // Wait lanes to be ready
auto start_time = wall_clock64(); auto start_time = wall_clock64();
while (cached_channel_tail_idx <= expected_head) { while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
static_cast<int>(ld_relaxed_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check // Timeout check
long long int elapsed_time = if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
wall_clock64() > start_time ? wall_clock64() - start_time : 0; printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
if (elapsed_time > NUM_TIMEOUT_CYCLES) { channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: "
"%d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx,
token_idx, expected_head);
trap(); trap();
} }
__builtin_amdgcn_s_sleep(1);
} }
syncwarp(); syncwarp();
// Combine current token // Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
return ld_nc_global(reinterpret_cast<const int4 *>( auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
rdma_channel_data.recv_buffer(src_rdma_rank) + combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
slot_idx * num_bytes_per_rdma_token) + expected_head, lane_id,
hidden_int4_idx); hidden_int4, num_topk,
}; combined_x + token_idx * hidden_int4,
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { combined_topk_weights + token_idx * num_topk,
return ld_nc_global(reinterpret_cast<const float *>( num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
rdma_channel_data.recv_buffer(src_rdma_rank) +
slot_idx * num_bytes_per_rdma_token + hidden_bytes +
sizeof(SourceMeta)) +
topk_idx);
};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks, kEmulatedWarpSize>(
expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk, num_max_rdma_chunked_recv_tokens,
recv_fn, recv_tw_fn);
} }
// Retired // Retired
syncwarp(); syncwarp();
if (lane_id == 0) if(lane_id == 0) {
rdma_receiver_retired[warp_id] = true; rdma_receiver_retired[target_rank] = true;
} else { }
lane_id = get_lane_id(); } else if(warp_role == WarpRole::kNVLCoordinator) {
// Coordinator // Coordinator
// Sync shared memory status // Sync shared memory status
is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); // sync_rdma_receiver_smem();
__syncthreads();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0; int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0}; int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
"Invalid number of forwarder warps");
while (true) { while(true) {
// Retired // Retired
if (is_rdma_receiver_sm and if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
__all_sync(kFullWarpMask,
lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
if (not is_rdma_receiver_sm and
__all_sync(kFullWarpMask,
lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break; break;
// Find minimum head for RDMA ranks // Find minimum head for RDMA ranks
if (is_rdma_receiver_sm) { {
int min_head = std::numeric_limits<int>::max(); int min_head = std::numeric_limits<int>::max();
#pragma unroll #pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++i) for(int i = 0; i < kNumRDMAReceivers; ++i)
if (not rdma_receiver_retired[i]) if(not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and
min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add( rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank)); translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head; last_rdma_head = min_head;
} }
} else {
// Find minimum head for NVL ranks
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
for (int j = 0; j < num_warps_per_rdma_rank; ++j)
if (not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head,
forwarder_nvl_head[i * num_warps_per_rdma_rank + j]
[dst_nvl_rank]);
if (min_head != std::numeric_limits<int>::max() and
min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS)
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i,
last_nvl_head[i] = min_head);
}
} }
// Nanosleep and let other warps work // Nanosleep and let other warps work
...@@ -2017,7 +1899,7 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -2017,7 +1899,7 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) { int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
constexpr int kNumCombineForwarderWarps = 16; constexpr int kNumCombineForwarderWarps = 8;
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ #define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
{ \ { \
...@@ -2037,30 +1919,26 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -2037,30 +1919,26 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
} \ } \
break break
int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks - num_warps_per_forwarder >=
num_max_nvl_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder);
EP_HOST_ASSERT(type == HIP_R_16BF); EP_HOST_ASSERT(type == HIP_R_16BF);
SETUP_LAUNCH_CONFIG(num_channels * 2, SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
(NUM_MAX_NVL_PEERS + num_forwarder_warps) * kEmulatedWarpSize + kWarpSize,
stream);
SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE #undef COMBINE_LAUNCH_CASE
} }
} // namespace internode } // namespace internode
} // namespace deep_ep } // namespace deep_ep
#ifdef __clang__ // #ifdef __clang__
#pragma clang diagnostic pop // #pragma clang diagnostic pop
#endif // __clang__ // #endif // __clang__
#endif #endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "buffer.cuh"
#include "configs.cuh"
#include "launch_hip.cuh"
#include "utils_hip.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
namespace deep_ep {
namespace internode {
extern rocshmem::rocshmem_team_t cpu_rdma_team;
// Per-token routing metadata carried alongside the token payload.
// Holds the source RDMA rank and a bitmask with one bit per NVL peer
// (bit `i` set <=> the token is destined for NVL rank `i`).
struct SourceMeta {
    int src_rdma_rank;
    int is_token_in_nvl_rank_bits;

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // Packs the `NUM_MAX_NVL_PEERS` boolean flags into the bitmask, flag `i` -> bit `i`.
    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        int packed_bits = 0;
#pragma unroll
        for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
            packed_bits |= static_cast<int>(is_token_in_nvl_ranks[peer]) << peer;
        src_rdma_rank = rdma_rank;
        is_token_in_nvl_rank_bits = packed_bits;
    }

    // Returns whether the bit for `nvl_rank` is set in the bitmask.
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        const int shifted = is_token_in_nvl_rank_bits >> nvl_rank;
        return (shifted & 1) != 0;
    }
};
// Host-side accessor: size in bytes of one `SourceMeta` record.
int get_source_meta_bytes() {
    constexpr int kSourceMetaBytes = static_cast<int>(sizeof(SourceMeta));
    return kSourceMetaBytes;
}
// Size in bytes of one token slot in the RDMA buffer: hidden payload (`hidden_int4` int4
// elements), one `SourceMeta`, `num_scales` floats, `num_topk_idx` ints and
// `num_topk_weights` floats, rounded up via `ALIGN` to a multiple of `sizeof(int4)`.
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    const auto payload_bytes = hidden_int4 * sizeof(int4);
    const auto scale_bytes = num_scales * sizeof(float);
    const auto topk_idx_bytes = num_topk_idx * sizeof(int);
    const auto topk_weight_bytes = num_topk_weights * sizeof(float);
    const auto unaligned_bytes =
        payload_bytes + sizeof(SourceMeta) + scale_bytes + topk_idx_bytes + topk_weight_bytes;
    return static_cast<int>(ALIGN(unaligned_bytes, sizeof(int4)));
}
// Describes the region of the RDMA buffer that must be zeroed before a dispatch.
//
// Returns {offset, count}, both measured in `int32_t` elements:
//   - offset: size of the token data area, i.e.
//             bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms,
//             divided by sizeof(int);
//   - count:  (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms.
//
// The byte product is evaluated in 64-bit arithmetic: all factors are `int`, and a plain
// `int` product can overflow (undefined behaviour) for large buffers before the division
// brings it back into range. Results are unchanged for every input whose product fits in `int`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    const int64_t bytes_per_token = static_cast<int64_t>(
        get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights));
    const int64_t data_bytes =
        bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms;
    const int clean_offset =
        static_cast<int>(data_bytes / static_cast<int64_t>(sizeof(int)));
    const int clean_count = (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms;
    return {clean_offset, clean_count};
}
// Describes the region of the NVL buffer that must be zeroed before a dispatch.
// Returns {offset, count}, both measured in `int32_t` elements: the offset is the size of
// the per-token data area and the count is `num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    // Return `int32_t` offset and count to clean
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
                     "Invalid size of `SourceMeta`");
    // Bytes occupied by one token: payload, scales, top-k indices, top-k weights, source meta.
    const auto bytes_per_token = hidden_int4 * sizeof(int4) + num_scales * sizeof(float) +
                                 num_topk_idx * sizeof(int) +
                                 num_topk_weights * sizeof(float) + sizeof(SourceMeta);
    const auto data_area_ints =
        (num_nvl_recv_buffer_tokens * bytes_per_token * num_nvl_ranks * num_sms) / sizeof(int);
    const auto clean_count = num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms;
    return {data_area_ints, clean_count};
}
// Maps a destination RDMA rank to the PE index used for rocSHMEM calls.
// In low-latency mode the result is `dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank`;
// otherwise `dst_rdma_rank` is returned unchanged.
template <bool kLowLatencyMode>
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
    if (kLowLatencyMode)
        return dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
    return dst_rdma_rank;
}
// Inter-node barrier used by the notify/dispatch kernels.
// Currently always issues a global `rocshmem_barrier_all()`; neither `kLowLatencyMode`
// nor `rdma_team` influences the call.
template <bool kLowLatencyMode>
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm.
    // Disabled team-scoped variant, kept for reference:
    //   kLowLatencyMode
    //       ? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
    //       : rocshmem::rocshmem_barrier_all();
    (void)rdma_team;  // unused while the team-scoped barrier is disabled
    rocshmem::rocshmem_barrier_all();
}
// Kernel that prepares an internode dispatch.
//
// Work is split by block index:
//   - Block 0 exchanges token counts (per rank, per expert, per RDMA rank) with the other
//     RDMA ranks via rocSHMEM puts and with the NVL peers via `buffer_ptrs`, zeroes the
//     "clean" regions of the RDMA and NVL buffers, writes `recv_rdma_rank_prefix_sum` /
//     `recv_gbl_rank_prefix_sum`, and publishes totals through the `moe_recv_*_mapped`
//     counters (presumably host-visible mapped memory; confirm with the caller).
//   - Block b > 0 handles destination RDMA rank `b - 1`: it counts, per channel, how many
//     tokens go to that RDMA rank and to each of its NVL ranks, then converts the
//     per-channel counts into inclusive prefix sums over channels.
//
// Device-side preconditions asserted in block 0: more than one warp per block,
// kNumRDMARanks <= blockDim.x, num_rdma_experts <= blockDim.x,
// NUM_MAX_NVL_PEERS <= blockDim.x and num_nvl_experts <= blockDim.x.
template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
const int nvl_clean_offset, const int nvl_num_int_clean,
int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
const rocshmem::rocshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
lane_id = get_lane_id();
auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
// Global rank decomposed as rank = rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank.
auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
auto num_rdma_experts = num_experts / kNumRDMARanks,
num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;
if (sm_id == 0) {
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
// Only the first thread of the second warp issues the inter-node barrier; every thread
// then calls `barrier_block` for the intra-node part.
if (thread_id == kWarpSize)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
// Send numbers of tokens per rank/expert to RDMA ranks
// Per-RDMA-rank slot layout, as filled below:
//   [0, NUM_MAX_NVL_PEERS)                      tokens per NVL rank of that RDMA rank
//   [NUM_MAX_NVL_PEERS, +num_rdma_experts)      tokens per expert of that RDMA rank
//   [NUM_MAX_NVL_PEERS + num_rdma_experts]      total tokens for that RDMA rank
auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
auto rdma_recv_num_tokens_mixed = SymBuffer<int>(
rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
// Clean up for later data dispatch
// The assert checks that the count-exchange area ends before the region being zeroed.
EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <=
rdma_clean_offset * sizeof(int));
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
// Copy to send buffer
for (int i = thread_id; i < num_ranks; i += num_threads)
rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] =
num_tokens_per_rank[i];
for (int i = thread_id; i < num_experts; i += num_threads)
rdma_recv_num_tokens_mixed.send_buffer(
i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =
num_tokens_per_expert[i];
if (thread_id < kNumRDMARanks)
rdma_recv_num_tokens_mixed.send_buffer(
thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
num_tokens_per_rdma_rank[thread_id];
__syncthreads();
// Issue send
// TODO: more light fence or barrier or signaling
// TODO: overlap EP barrier and NVL cleaning
// Thread t pushes the slot prepared for RDMA rank t into that peer's
// `recv_buffer(rdma_rank)`, i.e. the slot indexed by this rank on the remote side.
if (thread_id < kNumRDMARanks) {
rocshmem::rocshmem_int_put_nbi(
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(thread_id),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
}
__syncthreads();
// NOTE(review): the non-blocking puts above appear to rely on this barrier for
// completion and remote visibility; confirm against rocSHMEM barrier semantics.
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// NVL buffers
// `nvl_send_buffer` differs per thread: thread t < NUM_MAX_NVL_PEERS targets the
// buffer of NVL peer t, all other threads get nullptr.
auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
auto nvl_reduced_num_tokens_per_expert =
Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
auto nvl_send_num_tokens_per_rank =
AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
auto nvl_send_num_tokens_per_expert =
AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
auto nvl_recv_num_tokens_per_rank =
AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
auto nvl_recv_num_tokens_per_expert =
AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
// Clean up for later data dispatch
auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
nvl_send_num_tokens_per_rank.total_bytes +
nvl_send_num_tokens_per_expert.total_bytes <=
nvl_clean_offset * sizeof(int));
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
// Reduce number of tokens per expert into the NVL send buffer
// TODO: may use NVSHMEM reduction
EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
// One thread per local-node expert sums the counts received from every RDMA rank.
if (thread_id < num_rdma_experts) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++i)
sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
nvl_reduced_num_tokens_per_expert[thread_id] = sum;
}
__syncthreads();
// Reduce RDMA received tokens
if (thread_id == 0) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++i) {
sum +=
rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
recv_rdma_rank_prefix_sum[i] = sum;
}
// Spin until the counter reads -1, then publish the total
// (presumably -1 is the "ready to be written" value set by the host; confirm).
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
;
*moe_recv_rdma_counter_mapped = sum;
}
// Send numbers of tokens per rank/expert to NVL ranks
EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
// Thread t writes into NVL peer t's buffer (see `nvl_send_buffer`), in the sub-buffer
// indexed by this rank's `nvl_rank`.
if (thread_id < NUM_MAX_NVL_PEERS) {
#pragma unroll
for (int i = 0; i < kNumRDMARanks; ++i)
nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] =
rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
for (int i = 0; i < num_nvl_experts; ++i)
nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
}
memory_fence();
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
__syncthreads();
// Reduce number of tokens per rank/expert
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
if (thread_id == 0) {
int sum = 0;
// Running sum over global ranks in rank order -> inclusive prefix sum.
for (int i = 0; i < num_ranks; ++i) {
int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
recv_gbl_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_counter_mapped) != -1)
;
*moe_recv_counter_mapped = sum;
}
if (thread_id < num_nvl_experts) {
int sum = 0;
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
// Round the per-expert count up to a multiple of `expert_alignment`.
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
;
moe_recv_expert_counter_mapped[thread_id] = sum;
}
// Finally barrier
__syncthreads();
if (thread_id == kWarpSize)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else {
// Calculate meta data
// Blocks 1..N map one-to-one onto destination RDMA ranks 0..N-1.
int dst_rdma_rank = sm_id - 1;
// Warps stride over channels; within a channel, lanes stride over its tokens.
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
token_end_idx);
// Iterate over tokens
int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize) {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
"Invalid number of NVL peers");
// Load the NUM_MAX_NVL_PEERS bool flags of this (token, RDMA rank) pair as one
// 64-bit word, then view them again as individual bools.
auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>(
is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
auto is_token_in_rank_values =
reinterpret_cast<const bool *>(&is_token_in_rank_uint64);
#pragma unroll
for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)
per_nvl_rank_count[j] += is_token_in_rank_values[j];
// A token counts once for the RDMA rank if any of its NVL flags is set.
total_count += (is_token_in_rank_uint64 != 0);
}
// Warp reduce
total_count = warp_reduce_sum(total_count);
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);
// Write into channel matrix
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) *
num_channels +
channel_id] = per_nvl_rank_count[i];
rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
}
}
// Calculate prefix sum
__syncthreads();
// Thread 0 turns this RDMA rank's per-channel counts into an inclusive prefix sum.
if (thread_id == 0) {
auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
for (int i = 1; i < num_channels; ++i)
prefix_row[i] += prefix_row[i - 1];
}
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
// The first NUM_MAX_NVL_PEERS threads do the same for one global-rank row each.
if (thread_id < NUM_MAX_NVL_PEERS) {
auto prefix_row = gbl_channel_prefix_matrix +
(dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
for (int i = 1; i < num_channels; ++i)
prefix_row[i] += prefix_row[i - 1];
}
}
}
// Host-side launcher for the `notify_dispatch` kernel.
//
// Computes the {offset, count} clean regions for the RDMA and NVL buffers, checks that
// they fit inside `num_rdma_bytes` / `num_nvl_bytes` and that both buffer sizes are below
// INT_MAX, then launches the kernel on `stream`.
//
// `hidden_int4`, `num_scales`, `num_topk`, `num_max_rdma_chunked_recv_tokens` and
// `num_max_nvl_chunked_recv_tokens` are only used here to size the clean regions; the
// remaining pointer/count arguments are forwarded to the kernel unchanged.
// `low_latency_mode` selects the kernel's `kLowLatencyMode` template argument.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int num_experts, const bool *is_token_in_rank, int num_tokens,
int num_channels, int hidden_int4, int num_scales, int num_topk,
int expert_alignment, int *rdma_channel_prefix_matrix,
int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode) {
// Launch body for one compile-time RDMA rank count. It is expanded by SWITCH_RDMA_RANKS
// below (presumably a switch over `num_rdma_ranks`; confirm in the launch header) and
// uses the locals `cfg`, `rdma_clean_meta` and `nvl_clean_meta`; `cfg` is not declared
// in this function and presumably comes from SETUP_LAUNCH_CONFIG (confirm).
// The file-level `cpu_rdma_team` is passed as the kernel's RDMA team.
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto notify_dispatch_func = low_latency_mode ? notify_dispatch<true, num_rdma_ranks> \
: notify_dispatch<false, num_rdma_ranks>; \
LAUNCH_KERNEL_NON_COOPERATIVE( \
&cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \
num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert, \
moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens, \
num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second, \
nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix, \
recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team); \
} \
break
constexpr int kNumThreads = 256;
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
// Get clean meta
// `num_topk` is passed for both the top-k index count and the top-k weight count, and
// `num_channels` is passed as the `num_sms` argument of both helpers.
auto rdma_clean_meta =
get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
num_max_rdma_chunked_recv_tokens, num_channels);
auto nvl_clean_meta =
get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
// The end of each clean region (offset + count, in ints) must lie inside its buffer.
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
num_rdma_bytes);
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
num_nvl_bytes);
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
// Launch kernel
// One block for the count exchange plus one block per RDMA rank (the kernel maps block
// b > 0 to destination RDMA rank b - 1), each with kNumThreads threads.
SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
// Number of RDMA ranks a single token may be sent to: `num_rdma_ranks`, capped at 8.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    constexpr int kMaxTopkRDMARanks = 8;
    if (num_rdma_ranks >= kMaxTopkRDMARanks)
        return kMaxTopkRDMARanks;
    return num_rdma_ranks;
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
bool kCachedMode,
int kNumDispatchRDMASenderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks) {
enum class WarpRole {
kRDMASender, // 从x写入到RDMA发送缓存
kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
kRDMAAndNVLForwarder, // 从RDMA接收缓存转写到ipc nvl缓存
kForwarderCoordinator, // 向远端RDMA确认接收
kNVLReceivers // 从nvl缓存写入到recv_x
};
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
if(warp_id < kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASender, -1};
} else if(warp_id == kNumDispatchRDMASenderWarps) {
return {WarpRole::kRDMASenderCoordinator, -1};
}
} else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
if(warp_id < NUM_MAX_NVL_PEERS) {
return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
} else {
return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
}
} else {
return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
}
}();
auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders
// if(lane_id==0){
// printf("tid=%d, bid=%d, warp_role=%d\n", threadIdx.x, blockIdx.x, warp_role);
// }
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL buffer layouts
// NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
int rs_wr_rank = 0, ws_rr_rank = 0;
if (warp_role == WarpRole::kRDMAAndNVLForwarder)
rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
if (warp_role == WarpRole::kNVLReceivers)
rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
// Allocate buffers
auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx;
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
// NVL and RDMA coordinate Forward warp synchronization
__shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
__shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
// Place the main logic of your kernel here, using the parameters above.
if(warp_role == WarpRole::kRDMASender) {
/*
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
最后,它复制相关的数据到对称发送缓冲区。
kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
*/
// 获取任务范围
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// 清理共享内存
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
if(warp_id == 0 && lane_id == 0) {
rdma_send_next_token_idx = token_start_idx;
}
if(warp_id == 0 && lane_id < kNumRDMARanks) {
rdma_send_channel_tail[lane_id] = 0;
rdma_send_channel_next_tail[lane_id] = 0;
}
// 发送本通道中的令牌数量,通过 `-value - 1` 表示
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
// 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
// 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
if (lane_id < NUM_MAX_NVL_PEERS) {
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
}
syncwarp();
if (dst_rdma_rank != rdma_rank) {
rocshmem::rocshmem_ctx_int_put_nbi_wave(
ctx, rdma_channel_meta.recv_buffer(rdma_rank),
rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
rocshmem::rocshmem_ctx_quiet(ctx);
// sync_rdma_sender_smem();
__syncthreads();
// 遍历令牌并复制到缓冲区
int64_t token_idx;
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
// 读取RDMA秩的存在性
uint64_t is_token_in_rank_uint64 = 0;
if(lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
}
// 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
// 等待
}
syncwarp();
// 获取下一个尾部位置
int rdma_tail_idx = -1;
if(is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
// 与kForwarderCoordinator相互配合,调节发送数据的频率
while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
}
}
syncwarp();
// 存储RDMA头部以供合并
if(lane_id < kNumRDMARanks && !kCachedMode) {
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
}
// 更新最后一个令牌尾部
if(last_rdma_tail_idx >= 0) {
st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
}
last_rdma_tail_idx = rdma_tail_idx;
// 释放顺序锁
if(lane_id == 0) {
rdma_send_next_token_idx += 1;
}
// 广播尾部位置
SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
void* dst_send_buffers[kNumTopkRDMARanks];
/*
该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
*/
#pragma unroll
for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
// 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
// warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
// 计算slot_idx在接收缓冲区中的位置
slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
// 存储当前RDMA秩到topk_ranks数组中
topk_ranks[num_topk_ranks] = i;
// 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
// 如果当前lane_id等于num_topk_ranks,则更新src_meta
if(lane_id == num_topk_ranks) {
src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
}
// 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
// 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
}
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
// 复制 `x` 到对称发送缓冲区
auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
}
};
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
}
// 复制源元数据到对称发送缓冲区
if(lane_id < num_topk_ranks) {
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
}
// 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
for(int i = lane_id; i < num_scales; i += kWarpSize) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
// auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
// auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
}
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
}
// 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
auto rank_idx = i / num_topk, copy_idx = i % num_topk;
auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
}
}
// 结尾部分
// 获取顺序锁
while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
// 等待
}
syncwarp();
// 更新最后一个令牌尾部
if(last_rdma_tail_idx >= 0) {
st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
}
// 释放顺序锁
if(lane_id == 0) {
rdma_send_next_token_idx += 1;
}
} else if(warp_role == WarpRole::kRDMASenderCoordinator) {
/*
这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。
kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
*/
if(warp_id > kNumDispatchRDMASenderWarps) {
return;
}
// 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// 同步共享内存,确保所有线程在继续之前都达到了这一点
// sync_rdma_sender_smem();
__syncthreads();
// 计算当前通道需要发送的令牌数
int num_tokens_to_send = 0;
if(lane_id < kNumRDMARanks) {
num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
if(channel_id > 0)
num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
}
// 记录上次发出的尾部位置
int last_issued_tail = 0;
// 当有任何RDMA秩需要发送令牌时,继续循环
while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
// 计算目标RDMA秩
int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
// 获取同步后的需要发送的令牌数
synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
if(synced_num_tokens_to_send == 0)
continue; // 如果没有令牌需要发送,则跳过
// 读取进度
auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
auto num_tokens_processed = processed_tail - synced_last_issued_tail;
// 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue;
// 计算本次需要发出的令牌数
auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);
// 发出RDMA发送请求
if(dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
rdma_channel_data.recv_buffer(rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
dst_slot_idx * num_bytes_per_rdma_token,
num_bytes_per_rdma_token * num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
} else {
// 对于本地RDMA秩,使用较轻的内存屏障
memory_fence();
}
// 更新尾部位置
syncwarp();
if(lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
// 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
} // while(__any(num_tokens_to_send > 0))
} else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
/*
这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
最后,它同步头部和尾部索引,并标记通道为退役状态。
*/
// RDMA消费者和NVL生产者
const auto dst_nvl_rank = target_rank; // 目标NVL秩
const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; // 目标秩
const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); // 目标秩专家开始
const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束
// 等待计数器到达
int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
auto start_time = wall_clock64();
if(lane_id < kNumRDMARanks) {
while(true) {
// 对应于kRDMASender中的数据写入
auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); // 是nvl节点的起始地址
auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); // 本rdma节点的起始地址
auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); // 本节点的结束地址
if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
// 通知NVL秩
int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);
st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);
// 保存从RDMA通道接收的令牌计数
src_rdma_channel_prefix = -meta_2 - 1;
auto src_rdma_channel_prefix_1 = -meta_3 - 1;
num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
if(!kCachedMode)
recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
break;
}
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
trap();
}
}
}
syncwarp();
// 移动缓存的头部
send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;
// 等待共享内存被清理
// sync_forwarder_smem();
__syncthreads();
// 开始准备处理接受数据,直到所有的数据接受完成。
// 转发从RDMA缓冲区的令牌
// 注意:总是从本地秩开始
int src_rdma_rank = sm_id % kNumRDMARanks;
int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
// 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
start_time = wall_clock64();
// 用于给kNVLReceivers进行互动,控制数据的传输速度
while(lane_id == 0) {
int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
break;
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
trap();
}
}
syncwarp();
// 找到下一个源RDMA秩(轮询)
start_time = wall_clock64();
while(true) {
src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
break;
}
}
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
trap();
}
}
auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
// 遍历RDMA缓冲区中的每一个令牌
for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
// 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
if(lane_id == src_rdma_rank) {
num_tokens_to_recv_from_rdma -= 1;
}
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if(lane_id == src_rdma_rank) {
auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
rdma_nvl_token_idx += is_in_dst_nvl_rank;
if(!kCachedMode)
send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
}
if(!is_in_dst_nvl_rank)
continue;
// 获取一个空闲槽位
int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
// 复制数据
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
// 复制源元数据
if(lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
// 复制 `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<float*>(shifted) + num_scales;
// 复制 `topk_idx` 和 `topk_weights`
if(lane_id < num_topk) {
// 读取
auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
shifted = reinterpret_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
// 转换和写入
idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
}
// 在NVL缓冲区不足的情况下,提前停止
if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
src_rdma_tail = i + 1;
}
// 同步头部索引
if(lane_id == src_rdma_rank)
forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
// 移动尾部索引,与kNVLReceivers互相通信使用
syncwarp();
if(lane_id == 0) {
st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
}
}
// Retired
syncwarp();
if(lane_id == 0) {
forward_channel_retired[dst_nvl_rank] = true;
}
} else if(warp_role == WarpRole::kForwarderCoordinator) {
/*
这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
然后,它清理共享内存,并初始化转发通道的头部和退役状态。
接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
*/
// Extra warps for forwarder coordinator should exit directly
if (warp_id > NUM_MAX_NVL_PEERS)
return;
// 转发warp协调器
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
// 清理共享内存
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
if(lane_id < NUM_MAX_NVL_PEERS)
forward_channel_retired[lane_id] = false;
// sync_forwarder_smem();
__syncthreads();
int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
while(true) {
// 找到最小的头部
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
if(!forward_channel_retired[i])
min_head = min(min_head, forward_channel_head[i][target_rdma]);
if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
break;
}
// 更新远程头部
if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head;
}
// 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
} else if(warp_role == WarpRole::kNVLReceivers) {
if(warp_id >= NUM_MAX_NVL_PEERS) {
return;
}
// Place the main logic of your kernel here, using the parameters above.
// NVL消费者
// 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
int src_nvl_rank = target_rank, total_offset = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
// 接收通道偏移
int start_offset = 0, end_offset = 0, num_tokens_to_recv;
auto start_time = wall_clock64();
while(lane_id < kNumRDMARanks) {
start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
if(start_offset < 0 && end_offset < 0) {
start_offset = -start_offset - 1, end_offset = -end_offset - 1;
total_offset += start_offset;
break;
}
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
trap();
}
}
num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
// 保存以供合并使用
if(lane_id < kNumRDMARanks && !kCachedMode)
recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
syncwarp();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while(num_tokens_to_recv > 0) {
// 通过通道0检查通道状态
start_time = wall_clock64();
while(lane_id == 0) {
// 准备复制
if(cached_channel_head_idx != cached_channel_tail_idx)
break;
cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
// 超时检查
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
trap();
}
}
// 同步队列尾部
cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);
// 复制数据
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
int64_t recv_token_idx = shfl_sync(total_offset, meta.src_rdma_rank);
(lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
// 复制数据
UNROLLED_WARP_COPY(5,
lane_id,
hidden_int4,
recv_x + recv_token_idx * hidden_int4,
nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
ld_nc_global,
st_na_global);
// 复制源元数据
if(lane_id == 0 && !kCachedMode)
st_na_global(recv_src_meta + recv_token_idx, meta);
// 复制比例
UNROLLED_WARP_COPY(1,
lane_id,
num_scales,
recv_x_scales + recv_token_idx * num_scales,
nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
ld_nc_global,
st_na_global);
// 复制 `topk_idx` 和 `topk_weights`
if(lane_id < num_topk) {
auto recv_idx = recv_token_idx * num_topk + lane_id;
auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
}
}
// 移动队列
syncwarp();
if(lane_id == 0) {
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
} // while(num_tokens_to_recv > 0)
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
}
/**
 * Host-side launcher for the internode `dispatch` kernel template.
 *
 * Selects one of four kernel instantiations from two runtime flags and launches it:
 *   - `low_latency_mode`   becomes the first template argument,
 *   - `is_cached_dispatch` becomes the third template argument,
 *   - the RDMA rank count (second template argument) is supplied by `SWITCH_RDMA_RANKS`,
 *   - `kNumDispatchRDMASenderWarps` (fixed to 7 here) is the fourth template argument.
 *
 * Pointer conversions performed before the launch: `recv_x` -> `int4*`,
 * `x` -> `const int4*`, `recv_src_meta` -> `SourceMeta*`. Every other kernel argument is
 * forwarded unchanged.
 *
 * `stream`, `num_channels`, `is_cached_dispatch` and `low_latency_mode` are consumed on the
 * host (launch configuration / template selection) and are not passed to the kernel.
 *
 * Preconditions checked with `EP_HOST_ASSERT`:
 *   - `num_scales * scale_hidden_stride`, evaluated in 64-bit, is below `INT_MAX`;
 *   - `topk_idx` and `topk_weights` are both null or both non-null;
 *   - `recv_topk_idx` and `recv_topk_weights` are both null or both non-null.
 */
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
bool low_latency_mode) {
// Compile-time warp count handed to the kernel as its fourth template argument.
constexpr int kNumDispatchRDMASenderWarps = 7;
// Make sure never OOB: the scale extent is computed in 64-bit so the check itself cannot
// overflow, and it must fit in a 32-bit signed int.
EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
// Per-RDMA-rank-count launch body. It picks the kernel instantiation matching the two runtime
// flags and launches it with `cfg`. The trailing `break` (without a semicolon) indicates the
// macro is meant to be expanded as the body of a `case` label.
// NOTE(review): `SWITCH_RDMA_RANKS` and `LAUNCH_KERNEL_NON_COOPERATIVE` are defined elsewhere;
// `cfg` is presumably declared by `SETUP_LAUNCH_CONFIG` below - confirm.
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
{ \
auto dispatch_func = \
low_latency_mode \
? (is_cached_dispatch \
? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> \
: dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) \
: (is_cached_dispatch \
? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> \
: dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>); \
LAUNCH_KERNEL_NON_COOPERATIVE( \
&cfg, dispatch_func, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx, \
recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta), \
reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head, \
send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales, \
num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, \
num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
} \
break
// Top-k indices and weights must be supplied (or omitted) together, on both the send and
// the receive side.
EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
// Launch shape: `num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL` and
// `(1 + NUM_MAX_NVL_PEERS) * kWarpSize`, on `stream`.
// NOTE(review): presumably (grid size, block size, stream) - confirm against the macro.
SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
(1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
// Expands `DISPATCH_LAUNCH_CASE` for each supported RDMA rank count.
SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
/**
 * Buffer clean-up and head fix-up kernel. Work is split by block index:
 *
 *   - block 0: zeroes `rdma_num_int_clean` ints of `rdma_buffer_ptr`, starting at int offset
 *              `rdma_clean_offset`, bracketed by two `nvshmem_barrier_with_same_gpu_idx`
 *              calls issued by thread 0.
 *   - block 1: zeroes `nvl_num_int_clean` ints of `buffer_ptrs[rank % NUM_MAX_NVL_PEERS]`,
 *              starting at int offset `nvl_clean_offset`, bracketed by two `barrier_block`
 *              calls.
 *   - block 2: (skipped when `is_cached_dispatch`) rewrites negative entries of
 *              `combined_rdma_head`. One warp per channel, one lane per RDMA rank.
 *   - others:  (skipped when `is_cached_dispatch`) rewrites negative entries of
 *              `combined_nvl_head`. One warp per channel, one lane per NVL peer, with blocks
 *              striding over destination RDMA ranks.
 *
 * In both fix-up paths tokens are walked from last to first. A negative entry is replaced by
 * `-h - 1`, where `h` is the closest non-negative entry already seen in that walk, or
 * `1 << 25` if none has been seen yet.
 *
 * Launch requirements visible here: at most 1024 threads per block (`__launch_bounds__`),
 * and at least `num_channels` warps per block when the fix-up paths run.
 */
template <bool kLowLatencyMode>
__global__ void __launch_bounds__(1024, 1)
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
              bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) {
    const auto block_id = static_cast<int>(blockIdx.x);
    const auto tid = static_cast<int>(threadIdx.x);
    const auto block_threads = static_cast<int>(blockDim.x);
    const auto num_warps = block_threads / kWarpSize;
    auto warp_idx = tid / kWarpSize;
    auto lane = get_lane_id();
    auto local_nvl_rank = rank % NUM_MAX_NVL_PEERS;
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Blocks 0 and 1 clean the RDMA and NVL buffers respectively; the rest fix up heads
    switch (block_id) {
    case 0: {
        // RDMA barrier before touching the buffer (single thread issues it)
        if (tid == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        __syncthreads();

        // Zero the requested window, block-strided over the threads
        auto rdma_words = reinterpret_cast<int *>(rdma_buffer_ptr);
        for (int word = tid; word < rdma_num_int_clean; word += block_threads)
            rdma_words[rdma_clean_offset + word] = 0;
        rocshmem::rocshmem_fence();
        __syncthreads();

        // RDMA barrier once the window has been cleared
        if (tid == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        break;
    }
    case 1: {
        // NVL barrier before touching the buffer
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, local_nvl_rank);
        __syncthreads();

        // Zero the requested window of this rank's NVL buffer
        auto nvl_words = reinterpret_cast<int *>(buffer_ptrs[local_nvl_rank]);
        for (int word = tid; word < nvl_num_int_clean; word += block_threads)
            nvl_words[nvl_clean_offset + word] = 0;
        memory_fence();
        __syncthreads();

        // NVL barrier once the window has been cleared
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, local_nvl_rank);
        break;
    }
    case 2: {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);

        // Only lanes mapped to an RDMA rank, in warps mapped to a channel, have work
        if (lane >= num_rdma_ranks or warp_idx >= num_channels)
            return;

        int range_begin, range_end;
        get_channel_task_range(num_combined_tokens, num_channels, warp_idx, range_begin,
                               range_end);

        // NOTES: `1 << 25` is a heuristic large number
        int next_head = 1 << 25;
        // Walk the channel's tokens backwards
        for (int token_idx = range_end - 1; token_idx >= range_begin; --token_idx) {
            int *slot = combined_rdma_head + token_idx * num_rdma_ranks + lane;
            auto head = __ldg(slot);
            if (head >= 0) {
                next_head = head;
                continue;
            }
            *slot = -next_head - 1;
        }
        break;
    }
    default: {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
                         rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");

        // Only lanes mapped to an NVL peer, in warps mapped to a channel, have work
        if (lane >= NUM_MAX_NVL_PEERS or warp_idx >= num_channels)
            return;

        // Blocks 3.. share the destination RDMA ranks among themselves
        const int rank_stride = num_channels * 2 - 3;
        for (int dst_rdma_rank = block_id - 3; dst_rdma_rank < num_rdma_ranks;
             dst_rdma_rank += rank_stride) {
            // Token range of this (RDMA rank, channel) pair from the prefix matrix
            int range_begin = 0;
            if (warp_idx != 0)
                range_begin = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_idx - 1];
            int range_end = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_idx];

            // Shift by the tokens that belong to the preceding RDMA ranks
            int rank_base = 0;
            if (dst_rdma_rank != 0)
                rank_base = rdma_rank_prefix_sum[dst_rdma_rank - 1];
            range_begin += rank_base;
            range_end += rank_base;

            // NOTES: `1 << 25` is a heuristic large number
            int next_head = 1 << 25;
            // Walk the tokens backwards
            for (int token_idx = range_end - 1; token_idx >= range_begin; --token_idx) {
                int *slot = combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane;
                auto head = __ldg(slot);
                if (head >= 0) {
                    next_head = head;
                    continue;
                }
                *slot = -next_head - 1;
            }
        }
        break;
    }
    }
}
/**
 * Host-side launcher for the `cached_notify` kernel.
 *
 * Computes which parts of the RDMA and NVL buffers have to be zeroed
 * (`get_rdma_clean_meta` / `get_nvl_clean_meta`, whose `.first` / `.second` are forwarded as
 * the kernel's clean offset and int count), validates them against the buffer sizes, and
 * launches `cached_notify<low_latency_mode>` on `stream` with `num_channels * 2` blocks.
 *
 * Preconditions checked with `EP_HOST_ASSERT`:
 *   - the block size `max(128, kWarpSize * num_channels)` does not exceed 1024, the
 *     `__launch_bounds__` limit the kernel is compiled with;
 *   - both clean windows, in bytes, end inside `num_rdma_bytes` / `num_nvl_bytes`;
 *   - `num_rdma_bytes` and `num_nvl_bytes` are below `INT_MAX`;
 *   - `num_channels * 2 > 3`, so the kernel's block stride `num_channels * 2 - 3` over
 *     destination RDMA ranks is positive.
 *
 * `is_cached_dispatch` is forwarded to the kernel; `low_latency_mode` only selects the
 * template instantiation.
 */
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                   bool is_cached_dispatch, bool low_latency_mode) {
    // One warp per channel, but never fewer than 128 threads
    const int num_threads = ::max(128, kWarpSize * num_channels);
    // The kernel is declared `__launch_bounds__(1024, 1)`: a larger block cannot be launched,
    // so fail here with a readable message instead of at launch time
    EP_HOST_ASSERT(num_threads <= 1024);
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);

    // The windows to zero (offset + count, in ints) must end inside their buffers
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                   num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                   num_nvl_bytes);
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
    // Blocks 0..2 have dedicated roles; at least one more block is needed for the NVL heads
    EP_HOST_ASSERT(num_channels * 2 > 3);

    // Launch kernel
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
}
/**
 * Warp-cooperative reduction of a single token over every source rank that holds it.
 *
 * Each of the first `kNumRanks` lanes contributes its own `is_token_in_rank` / `head_idx`
 * pair, read through `shfl_sync(value, i)` for `i` in `[0, kNumRanks)`. For every
 * contributing rank the slot is `head_idx % num_max_recv_tokens`.
 *
 * Hidden data: for each `int4` column (strided over lanes by `kWarpSize`) the values fetched
 * via `recv_fn(rank, slot, column)` are viewed as `dtype_t` elements, summed in `float`,
 * cast back to `dtype_t` and stored to `combined_row[column]`.
 *
 * Top-k weights: lanes below `num_topk` sum `recv_tw_fn(rank, slot, lane_id)` over the
 * contributing ranks and store the result to `combined_topk_weights[lane_id]`.
 *
 * @return `topk_ranks[0]`, i.e. the lowest contributing rank index.
 *         NOTE(review): when no rank contributes, that element is never written and the
 *         returned value is indeterminate - confirm callers ignore it in that case.
 */
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Gather the contributing ranks and their buffer slots; every lane ends up with the
    // same lists because each entry comes from a warp shuffle
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int topk_ranks[kMaxNumRanks];
    int slot_indices[kMaxNumRanks];
    int num_topk_ranks = 0;
    #pragma unroll
    for (int src = 0; src < kNumRanks; ++src) {
        if (not shfl_sync(is_token_in_rank, src))
            continue;
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, src) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks] = src;
        num_topk_ranks += 1;
    }
    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);

    // Hidden data: one `int4` column per lane per iteration
    #pragma unroll
    for (int col = lane_id; col < hidden_int4; col += kWarpSize) {
        // Issue all loads first, then accumulate
        // TODO: maybe too many registers here
        int4 gathered[kMaxNumRanks];
        #pragma unroll
        for (int r = 0; r < num_topk_ranks; ++r)
            gathered[r] = recv_fn(topk_ranks[r], slot_indices[r], col);

        // Element-wise sum in single precision, in gather order
        float acc[kDtypePerInt4] = {0};
        #pragma unroll
        for (int r = 0; r < num_topk_ranks; ++r) {
            const dtype_t* elems = reinterpret_cast<const dtype_t*>(&gathered[r]);
            #pragma unroll
            for (int e = 0; e < kDtypePerInt4; ++e)
                acc[e] += static_cast<float>(elems[e]);
        }

        // Narrow back to `dtype_t`, repack and store
        int4 packed;
        dtype_t* packed_elems = reinterpret_cast<dtype_t*>(&packed);
        #pragma unroll
        for (int e = 0; e < kDtypePerInt4; ++e)
            packed_elems[e] = static_cast<dtype_t>(acc[e]);
        st_na_global(combined_row + col, packed);
    }

    // Top-k weights: one weight per lane
    if (lane_id < num_topk) {
        float weight_sum = 0;
        #pragma unroll
        for (int r = 0; r < num_topk_ranks; ++r)
            weight_sum += recv_tw_fn(topk_ranks[r], slot_indices[r], lane_id);
        st_na_global(combined_topk_weights + lane_id, weight_sum);
    }

    // Ranks were appended in ascending order, so the first one is the minimum
    return topk_ranks[0];
}
template <bool kLowLatencyMode,
int kNumRDMARanks,
typename dtype_t,
int kNumCombineForwarderWarps,
int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
int kNumRDMAReceivers = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1)
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks) {
enum class WarpRole {
kNVLSender,
kNVLAndRDMAForwarder,
kRDMAReceiver,
kRDMACoordinator,
kNVLCoordinator
};
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
channel_id = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));
// NOTES: we decouple a channel into 2 SMs
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
} else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
if(warp_id < kNumForwarders) {
return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
} else {
return {WarpRole::kRDMACoordinator, 0};
}
} else {
if(warp_id < kNumForwarders) {
return {WarpRole::kRDMAReceiver, warp_id};
} else {
return {WarpRole::kNVLCoordinator, 0};
}
}
}();
auto warp_role = role_meta.first;
auto target_rank = role_meta.second; // Not applicable for RDMA senders
EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;
// This approach is designed to sync multiple warps in a loop
constexpr int num_sync_large_iteration = 64;
constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
__shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];
for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
if (warp_role == WarpRole::kNVLSender) {
if(warp_id >= NUM_MAX_NVL_PEERS) {
return;
}
const auto dst_nvl_rank = target_rank;
// NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
// Get tasks for each RDMA lane
int token_start_idx = 0, token_end_idx = 0;
if(lane_id < kNumRDMARanks) {
int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
}
syncwarp();
// NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
// Iterate over all tokens and send by chunks
while(true) {
// Exit if possible
if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
break;
// Decide next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = wall_clock64();
while(true) {
int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
if(__any_sync(kFullWarpMask, is_lane_ready))
break;
// Retry
if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
// Timeout check
if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
"RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
channel_id,
rdma_rank,
nvl_rank,
dst_nvl_rank,
lane_id,
ld_volatile_global(nvl_channel_head.buffer() + lane_id),
cached_channel_tail_idx,
token_start_idx,
token_end_idx);
trap();
}
}
// Sync token start index and count
for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
continue;
// Sync token start index
auto token_idx = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
// Send by chunk
for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
// Get an empty slot
int dst_slot_idx = 0;
if(lane_id == current_rdma_idx) {
dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
}
dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
// Copy data
auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Copy source meta
if(lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
// Copy `topk_weights`
if(lane_id < num_topk)
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
}
lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
}
// Move queue tail
syncwarp();
if(lane_id < kNumRDMARanks and is_lane_ready) {
st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
}
}
} else {
if(warp_id > kNumForwarders) {
return;
}
// Combiners and coordinators
// RDMA symmetric layout
auto hidden_bytes = hidden_int4 * sizeof(int4);
auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
// NVL layouts
void* local_nvl_buffer = buffer_ptrs[nvl_rank];
void* nvl_buffers[NUM_MAX_NVL_PEERS];
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
nvl_buffers[i] = buffer_ptrs[i];
auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
// Combiner warp synchronization
__shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forwarder_retired[kNumForwarders];
__shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
__shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
// Receive from NVL ranks and forward to RDMA ranks
// NOTES: this part is using "large warps" for each RDMA ranks
const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
const auto sub_warp_id = target_rank % kNumWarpsPerForwarder;
auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
// auto sync_large_warp = [=]() {
// if(kNumWarpsPerForwarder == 1) {
// syncwarp();
// } else {
// // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
// // __syncthreads();
// syncwarp();
// }
// };
auto sync_large_warp = [=](const int iter, const int mode) {
if (kNumWarpsPerForwarder == 1) {
syncwarp();
} else {
// LDS index to store for sync
int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
//reset index in the LDS to avoid race condition due to warp scheduling
int reset_idx = dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
// // if (lane_id==0)
// // printf("rank %d dst_rdma_rank %d iter %d warp_id %d val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
auto start_time = wall_clock64();
if (lane_id == 0){
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
}
syncwarp();
//The while(...) loop polls the counter until all warps have arrived
if (lane_id == 0){
while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
trap();
}
}
}
syncwarp();
if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
sync_large_warp_counters[reset_idx] = 0;
}
syncwarp();
}
};
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
// Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
nvl_channel_head.advance(dst_rdma_rank);
nvl_channel_tail.advance(dst_rdma_rank);
// Clean shared memory and sync
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
// sync_forwarder_smem();
__syncthreads();
// Get count and cached head
int cached_nvl_channel_tail_idx = 0;
int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
num_tokens_to_combine -= num_tokens_prefix;
num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;
// Iterate over all tokens and combine by chunks
for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
// Check destination queue emptiness, or wait a buffer to be released
auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
auto num_chunked_tokens = token_end_idx - token_start_idx;
auto start_time = wall_clock64();
while(sub_warp_id == 0 and lane_id == 0) {
// Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
// Here, `token_start_idx` is the actual tail
int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
break;
// Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
trap();
}
}
// sync_large_warp();
sync_large_warp(token_start_idx, 0);
// Combine and write to the RDMA buffer
for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < NUM_MAX_NVL_PEERS)
expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
// Wait lanes to be ready
start_time = wall_clock64();
while(cached_nvl_channel_tail_idx <= expected_head) {
cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
// Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
trap();
}
}
// Combine current token
auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head
if(lane_id < NUM_MAX_NVL_PEERS) {
expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
: (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
}
}
// sync_large_warp();
sync_large_warp(token_start_idx, 1);
// Issue RDMA send
if(sub_warp_id == kNumWarpsPerForwarder - 1) {
if(dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
rocshmem::rocshmem_ctx_schar_put_nbi_wave(
ctx,
rdma_channel_data.recv_buffer(rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
rdma_channel_data.send_buffer(dst_rdma_rank) +
rdma_slot_idx * num_bytes_per_rdma_token,
num_chunked_tokens * num_bytes_per_rdma_token,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
rocshmem::rocshmem_ctx_quiet(ctx);
} else {
memory_fence();
}
// Write new RDMA tail
syncwarp();
if(lane_id == 0) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
}
}
// Retired
syncwarp();
if(lane_id == 0) {
forwarder_retired[target_rank] = true;
}
} else if (warp_role == WarpRole::kRDMACoordinator) {
// Coordinator
// Sync shared memory status
// sync_forwarder_smem();
__syncthreads();
constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
break;
{
// Find minimum head for NVL ranks
#pragma unroll
for(int i = 0; i < kNumRDMARanks; ++i) {
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int j = 0; j < num_warps_per_rdma_rank; ++j)
if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
}
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
} else if(warp_role == WarpRole::kRDMAReceiver) {
// Receive from RDMA ranks and write to the output tensor
// Clean shared memory and sync
EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
// sync_rdma_receiver_smem();
__syncthreads();
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
int cached_channel_tail_idx = 0;
for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
// Read expected head
EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
int expected_head = -1;
if(lane_id < kNumRDMARanks) {
expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
(expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
: (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
}
// Wait lanes to be ready
auto start_time = wall_clock64();
while (cached_channel_tail_idx <= expected_head) {
cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
// Timeout check
if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
trap();
}
}
syncwarp();
// Combine current token
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
combined_x + token_idx * hidden_int4,
combined_topk_weights + token_idx * num_topk,
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
}
// Retired
syncwarp();
if(lane_id == 0) {
rdma_receiver_retired[target_rank] = true;
}
} else if(warp_role == WarpRole::kNVLCoordinator) {
// Coordinator
// Sync shared memory status
// sync_rdma_receiver_smem();
__syncthreads();
const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;
int last_rdma_head = 0;
int last_nvl_head[kNumRDMARanks] = {0};
int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");
while(true) {
// Retired
if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
break;
// Find minimum head for RDMA ranks
{
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for(int i = 0; i < kNumRDMAReceivers; ++i)
if(not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
rocshmem::rocshmem_ctx_ulong_atomic_add(
ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
}
}
// Nanosleep and let other warps work
__builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
}
}
}
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
}
// Host-side launcher for the internode `combine` kernel.
//
// Selects the kernel instantiation by `low_latency_mode` and by the number of RDMA ranks
// (`num_ranks / NUM_MAX_NVL_PEERS`), validates the chunk/buffer configuration, and launches
// on `stream` with:
//   - grid:  `num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL` blocks
//   - block: `(NUM_MAX_NVL_PEERS + 1) * kWarpSize` threads
//
// Preconditions (enforced with `EP_HOST_ASSERT`):
//   - `type` must be `HIP_R_16BF`; only `hip_bfloat16` instantiations are launched.
//   - `num_ranks` must cover at least one full NVL group, i.e. `num_rdma_ranks > 0`.
//   - `num_max_nvl_chunked_recv_tokens` must be a multiple of the RDMA rank count, and its
//     per-RDMA-rank share must be strictly larger than both chunked-send sizes.
//
// `x`, `combined_x`, `bias_0` and `bias_1` are reinterpreted as `int4` arrays and `src_meta`
// as a `SourceMeta` array before being handed to the kernel.
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
    constexpr int kNumCombineForwarderWarps = 8;

    // One switch case per supported RDMA rank count; `cfg` is expected to be declared by
    // `SETUP_LAUNCH_CONFIG` below, before `SWITCH_RDMA_RANKS` expands this macro.
#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \
    { \
        auto combine_func = \
            low_latency_mode \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps> \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>; \
        LAUNCH_KERNEL_NON_COOPERATIVE( \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights, \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights, \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1), \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs, \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks); \
    } \
    break

    int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
    // `num_rdma_ranks` is used as a divisor below; reject `num_ranks < NUM_MAX_NVL_PEERS`
    // up front instead of hitting an integer division by zero on the host.
    EP_HOST_ASSERT(num_rdma_ranks > 0);

    // Split the fixed forwarder warp budget evenly across RDMA ranks (at least one warp each)
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
    int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);

    // The NVL receive buffer is partitioned per RDMA rank; every partition must be able to
    // hold more than one chunk of either kind.
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
    EP_HOST_ASSERT(type == HIP_R_16BF);

    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
} // namespace internode
} // namespace deep_ep
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
#endif
...@@ -393,9 +393,9 @@ __global__ void __launch_bounds__(kNumThreads, 1) ...@@ -393,9 +393,9 @@ __global__ void __launch_bounds__(kNumThreads, 1)
total_offset; total_offset;
num_tokens_to_recv -= total_offset; num_tokens_to_recv -= total_offset;
} }
total_offset = __shfl_sync(kFullWarpMask, total_offset, 0); total_offset = shfl_sync(total_offset, 0);
total_offset += rank_offset; total_offset += rank_offset;
num_tokens_to_recv = __shfl_sync(kFullWarpMask, num_tokens_to_recv, 0); num_tokens_to_recv = shfl_sync(num_tokens_to_recv, 0);
// Shared tail indices for different warps // Shared tail indices for different warps
__shared__ volatile int shared_channel_tail_idx[kNumRanks]; __shared__ volatile int shared_channel_tail_idx[kNumRanks];
...@@ -583,7 +583,7 @@ __global__ void cached_notify_combine(void **buffer_ptrs, int *send_head, int nu ...@@ -583,7 +583,7 @@ __global__ void cached_notify_combine(void **buffer_ptrs, int *send_head, int nu
? __ldg(send_head + token_idx * kNumRanks + rank_id) ? __ldg(send_head + token_idx * kNumRanks + rank_id)
: -1; : -1;
for (int i = 0; i < min(kWarpSize, token_idx_tail - token_start_idx + 1); ++i) { for (int i = 0; i < min(kWarpSize, token_idx_tail - token_start_idx + 1); ++i) {
const int head = __shfl_sync(kFullWarpMask, current_head, i); const int head = shfl_sync(current_head, i);
if (head < 0) { if (head < 0) {
if (lane_id == i) if (lane_id == i)
expected_head = -last_head - 1; expected_head = -last_head - 1;
...@@ -606,7 +606,7 @@ void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels, ...@@ -606,7 +606,7 @@ void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
barrier_signal_ptrs, rank); \ barrier_signal_ptrs, rank); \
break break
const int num_threads = std::max(128, kWarpSize * num_ranks); const int num_threads = ::max(128, kWarpSize * num_ranks);
EP_HOST_ASSERT(num_ranks <= num_threads); EP_HOST_ASSERT(num_ranks <= num_threads);
EP_HOST_ASSERT(num_threads <= 1024); EP_HOST_ASSERT(num_threads <= 1024);
EP_HOST_ASSERT(1 + num_channels <= num_channels * 2); EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "buffer.cuh"
#include "configs.cuh"
#include "hip/hip_runtime.h"
#include "launch_hip.cuh"
#include "utils_hip.cuh"
namespace deep_ep {
namespace intranode {
// Kernel that exchanges token counts between ranks and prepares prefix sums for dispatch.
//
// Work split by block index:
//   - Block 0 writes this rank's per-rank / per-expert token counts into the peers' buffers
//     (`buffer_ptrs[peer]`), then turns the locally received counts into prefix sums, reports
//     the totals through the `moe_recv_*` outputs, copies the rank prefix matrix into
//     `rank_prefix_matrix_copy` and zeroes `num_memset_int` ints after it.
//   - Block `b > 0` counts, channel by channel, how many local tokens are flagged for rank
//     `b - 1` in `is_token_in_rank`, and stores the inclusive prefix sum over channels into
//     `channel_prefix_matrix[(b - 1) * num_channels + ...]`.
//
// NOTE(review): `barrier_block`, `get_channel_task_range` and `warp_reduce_sum` are defined
// elsewhere; their exact semantics are assumed from their names — confirm.
template <int kNumRanks>
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int64_t *moe_recv_tokens_per_experts, int num_experts, int num_tokens,
                int num_channels, const bool *is_token_in_rank, int *channel_prefix_matrix,
                int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                void **buffer_ptrs, int **barrier_signal_ptrs, int rank) {
    const auto block_id = static_cast<int>(blockIdx.x);
    const auto tid = static_cast<int>(threadIdx.x);
    const auto block_threads = static_cast<int>(blockDim.x);
    const auto lane = tid % kWarpSize;
    const auto warp = tid / kWarpSize;
    const auto warps_in_block = block_threads / kWarpSize;

    // ---- Blocks 1..N: per-channel token counting for one destination rank ----
    if (block_id != 0) {
        const int peer = block_id - 1;
        int *peer_row = channel_prefix_matrix + peer * num_channels;

        // Each warp handles channels `warp, warp + warps_in_block, ...`
        for (int channel = warp; channel < num_channels; channel += warps_in_block) {
            int first_token, last_token;
            get_channel_task_range(num_tokens, num_channels, channel, first_token,
                                   last_token);

            // Lanes stride over the channel's tokens and count the ones flagged for `peer`
            int local_hits = 0;
            for (int64_t t = first_token + lane; t < last_token; t += kWarpSize)
                local_hits += is_token_in_rank[t * kNumRanks + peer];
            const int channel_total = warp_reduce_sum(local_hits);
            if (lane == 0)
                peer_row[channel] = channel_total;
        }
        __syncthreads();

        // A single thread converts the per-channel counts into an inclusive prefix sum
        if (tid == 0) {
            #pragma unroll
            for (int c = 1; c < num_channels; ++c)
                peer_row[c] += peer_row[c - 1];
        }
        return;
    }

    // ---- Block 0: cross-rank count exchange ----
    // Barrier first
    barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);

    // Thread `tid` publishes this rank's counts into the buffer of rank `tid`:
    //  - rank-count entry `[rank][tid]`: tokens going from this rank to rank `tid`
    //  - expert-count row `rank`: tokens going from this rank to each local expert of rank `tid`
    const int experts_per_rank = num_experts / kNumRanks;
    if (tid < kNumRanks) {
        int *peer_rank_counts = static_cast<int *>(buffer_ptrs[tid]);
        int *peer_expert_counts = peer_rank_counts + kNumRanks * kNumRanks;
        peer_rank_counts[rank * kNumRanks + tid] = num_tokens_per_rank[tid];
        #pragma unroll
        for (int e = 0; e < experts_per_rank; ++e)
            peer_expert_counts[rank * experts_per_rank + e] =
                num_tokens_per_expert[tid * experts_per_rank + e];
    }

    // Wait for all ranks to be finished
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

    // Column-wise inclusive prefix sum over source ranks; the last entry of column `rank`
    // is the total number of tokens this rank receives and is reported to the CPU.
    int *local_rank_counts = static_cast<int *>(buffer_ptrs[rank]);
    if (tid < kNumRanks) {
        int running = local_rank_counts[tid];
        #pragma unroll
        for (int src = 1; src < kNumRanks; ++src) {
            running += local_rank_counts[src * kNumRanks + tid];
            local_rank_counts[src * kNumRanks + tid] = running;
        }
        if (tid == rank)
            *moe_recv_counter_mapped = running;
    }

    // Per-expert totals over all source ranks, rounded up to a multiple of `expert_alignment`
    int *local_expert_counts = local_rank_counts + kNumRanks * kNumRanks;
    if (tid < experts_per_rank) {
        int total = 0;
        #pragma unroll
        for (int src = 0; src < kNumRanks; ++src)
            total += local_expert_counts[src * experts_per_rank + tid];
        const int aligned_total =
            (total + expert_alignment - 1) / expert_alignment * expert_alignment;
        moe_recv_expert_counter_mapped[tid] = aligned_total;
        moe_recv_tokens_per_experts[tid] = aligned_total;
    }
    __syncthreads();

    // Copy rank size prefix matrix to another tensor
    #pragma unroll
    for (int idx = tid; idx < kNumRanks * kNumRanks; idx += block_threads)
        rank_prefix_matrix_copy[idx] = local_rank_counts[idx];

    // Extra memset for later communication queue (overwrites the expert-count region)
    #pragma unroll
    for (int idx = tid; idx < num_memset_int; idx += block_threads)
        local_expert_counts[idx] = 0;

    // Barrier
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
// Host-side launcher for the `notify_dispatch<kNumRanks>` kernel.
//
// Launches `1 + num_ranks` blocks of 128 threads on `stream`: the kernel uses block 0 for the
// cross-rank count exchange and block `b > 0` for the per-channel counting of rank `b - 1`.
// All pointer/scalar arguments are forwarded to the kernel unchanged (note that the kernel's
// parameter order differs: `num_ranks` becomes the template argument and `num_channels`
// follows `num_tokens`).
//
// Preconditions (enforced with `EP_HOST_ASSERT`):
//   - `num_experts` is divisible by `num_ranks`.
//   - both `num_experts / num_ranks` and `num_ranks` fit in one block (<= 128), because
//     block 0 of the kernel assigns one thread per rank and one thread per local expert.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int64_t *moe_recv_tokens_per_experts, int num_experts, int num_tokens,
const bool *is_token_in_rank, int *channel_prefix_matrix,
int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void **buffer_ptrs, int **barrier_signal_ptrs, int rank, hipStream_t stream,
int num_channels) {
// One switch case per supported rank count; `cfg` is presumably declared by
// `SETUP_LAUNCH_CONFIG` below — confirm against the launch header.
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, notify_dispatch<ranks>, num_tokens_per_rank, \
moe_recv_counter_mapped, num_tokens_per_expert, \
moe_recv_expert_counter_mapped, moe_recv_tokens_per_experts, \
num_experts, num_tokens, num_channels, is_token_in_rank, \
channel_prefix_matrix, rank_prefix_matrix_copy, num_memset_int, \
expert_alignment, buffer_ptrs, barrier_signal_ptrs, rank); \
break
// Threads per block
constexpr int kNumThreads = 128;
EP_HOST_ASSERT(num_experts % num_ranks == 0);
EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
// Grid: one exchange block plus one counting block per destination rank
SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
// Dispatch the runtime `num_ranks` to the matching template instantiation
SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
// Simplified notify kernel for cached handles: instead of exchanging counts, it restores a
// previously computed rank prefix matrix into this rank's buffer and clears the ints that
// follow it.
//
// Local buffer layout written here (`buffer_ptrs[rank]`, viewed as `int`):
//   [0, kNumRanks * kNumRanks)                      <- copy of `rank_prefix_matrix`
//   [kNumRanks * kNumRanks, ... + num_memset_int)   <- zero-filled
//
// Both loops stride by `blockDim.x`, so the kernel works with any single-block size.
// NOTE(review): `barrier_block` is defined elsewhere; presumably a cross-rank barrier — confirm.
template <int kNumRanks>
__global__ void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int,
                                       void **buffer_ptrs, int **barrier_signal_ptrs, int rank) {
    // Synchronize before touching the buffer
    barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);

    const auto tid = static_cast<int>(threadIdx.x);
    const auto stride = static_cast<int>(blockDim.x);
    int *local_buffer = static_cast<int *>(buffer_ptrs[rank]);

    // Restore the cached prefix matrix
    #pragma unroll
    for (int idx = tid; idx < kNumRanks * kNumRanks; idx += stride)
        local_buffer[idx] = rank_prefix_matrix[idx];

    // Zero the region right after the matrix
    #pragma unroll
    for (int idx = tid; idx < num_memset_int; idx += stride)
        local_buffer[kNumRanks * kNumRanks + idx] = 0;

    // Barrier after cleaning
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
// Host-side launcher for the `cached_notify_dispatch<kNumRanks>` kernel.
//
// Launches a single block of 256 threads on `stream`; `num_ranks` selects the template
// instantiation and the remaining arguments are forwarded to the kernel unchanged.
void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int, void **buffer_ptrs,
int **barrier_signal_ptrs, int rank, int num_ranks,
hipStream_t stream) {
// One switch case per supported rank count; `cfg` is presumably declared by
// `SETUP_LAUNCH_CONFIG` below — confirm against the launch header.
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_dispatch<ranks>, rank_prefix_matrix, \
num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \
break
// Grid of 1 block, 256 threads per block
SETUP_LAUNCH_CONFIG(1, 256, stream);
// Dispatch the runtime `num_ranks` to the matching template instantiation
SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}
// Intranode dispatch kernel: forwards tokens from this rank to peer ranks through per-channel
// ring buffers that live in the receiving rank's `buffer_ptrs[...]` region.
//
// Launch layout (as used by the code below):
//  - `gridDim.x` must be even (device-asserted). Block `2 * c` is the sender and block `2 * c + 1`
//    is the receiver of channel `c`, so there are `gridDim.x / 2` channels.
//  - The `kNumThreads` threads of a block are split evenly over `kNumRanks` groups; the group
//    index is `responsible_rank` (the peer this group talks to).
//  - The receiver side uses `kNumRanks` ints of static shared memory.
//
// Sender blocks publish the negated channel start/end offsets, then copy selected tokens
// (`is_token_in_rank[token * kNumRanks + responsible_rank]`) in chunks of at most
// `num_max_send_tokens` into the peer's ring buffer and advance `tail_idx`.
// Receiver blocks wait for the offsets, drain the ring buffer into `recv_*` and advance
// `head_idx`.
//
// Selected parameters (only what the code below shows):
//  - `send_head[token * kNumRanks + r]`: written by senders with the ring-buffer tail position
//    used for that token, or -1 if the token is not sent to rank `r`.
//  - `recv_channel_offset[r * num_channels + c]`: written by receivers with the decoded start
//    offset of channel `c` from rank `r`.
//  - `channel_prefix_matrix[r * num_channels + c]`: read by senders as the end offset of channel
//    `c` for rank `r` (entry `c - 1` is used as the start offset).
//  - `num_worst_tokens`: when > 0, `recv_topk_idx` entries in
//    `[num_recv_tokens * num_topk, num_worst_tokens * num_topk)` are filled with -1 at the end.
//  - `topk_idx` / `topk_weights` and `recv_topk_idx` / `recv_topk_weights` must be null or
//    non-null in pairs (device-asserted).
template <int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4 *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
float *recv_topk_weights, int *recv_channel_offset, int *send_head, const int4 *x,
const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const bool is_sender = sm_id % 2 == 0;
EP_DEVICE_ASSERT(num_sms % 2 == 0);
// Several warps are responsible for a single rank
const auto num_threads_per_rank = kNumThreads / kNumRanks;
const auto num_channels = num_sms / 2;
const auto responsible_rank = (static_cast<int>(thread_id)) / num_threads_per_rank;
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
const auto responsible_channel = sm_id / 2;
int num_experts_per_rank = num_experts / kNumRanks;
EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
// One lane per top-k entry is used when copying `topk_idx` / `topk_weights`
EP_DEVICE_ASSERT(num_topk <= kWarpSize);
EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
// Senders address the peer's buffer, receivers address their own buffer
auto ptr = reinterpret_cast<void *>(
static_cast<int8_t *>(buffer_ptrs[is_sender ? responsible_rank : rank]) +
kNumRanks * kNumRanks * sizeof(int));
// `target_rank` is the sending rank of the (channel, sender) slot inside the receiver's buffer
int target_rank = is_sender ? rank : responsible_rank;
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
// Channel buffer metadata
// Senders are responsible for tails, and receivers are responsible for heads
// Stored on the receiver side
// `start_offset`: kNumChannels * kNumRanks * sizeof(int)
// `end_offset`: kNumChannels * kNumRanks * sizeof(int)
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
// NOTE(review): the previous comment also mentioned "retired signals" padded to `int64_t`;
// no such field is carved out below, so that remark looks stale.
// NOTE(review): the same `ptr` is handed to consecutive `Buffer` constructors, so `Buffer`
// presumably advances `ptr` past each region; confirm against the `Buffer` definition.
auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
// Channel data buffers, stored on the receiver side
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk *
// sizeof(int64_t)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk *
// sizeof(float)
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales *
// sizeof(float)
auto channel_x_buffers =
Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens,
channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_idx_buffers =
Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_weights_buffers =
Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers =
Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales,
channel_rank_offset * num_recv_buffer_tokens * num_scales);
if (is_sender) {
// Workers for sending
constexpr int num_send_warps = kNumThreads / kWarpSize;
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
const auto send_thread_id = thread_id;
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / kWarpSize;
EP_DEVICE_ASSERT(kNumRanks <= kWarpSize);
EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: this is for distinguishing zero tokens
// (receivers spin until the stored value is non-zero, so a plain 0 cannot be published)
if (lane_id == 0 and send_warp_id_in_rank == 0) {
int value = responsible_channel > 0
? channel_prefix_matrix[responsible_rank * num_channels +
responsible_channel - 1]
: 0;
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
}
syncwarp();
// Get tasks
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx,
token_end_idx);
// Iterate over all tokens and send by chunks
// Every warp of the rank group walks the same token range and keeps its own copy of the tail
int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
// NOTES: the head index received by different warps may not be the same
auto start_time = clock64();
while (lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real numbers are
// time-consuming
int num_used_slots =
cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
break;
// Rare cases to loop again
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf(
"DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n",
rank, responsible_channel);
trap();
}
}
syncwarp();
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
// NOTES: for the same token, the warp assigned to save `send_head` may be different
// from the warp assigned to send the following data
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
send_head[token_idx * kNumRanks + responsible_rank] =
is_token_in_rank[token_idx * kNumRanks + responsible_rank]
? cached_channel_tail_idx
: -1;
// Skip if not selected
if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) {
token_idx++;
continue;
}
// Get an empty slot
int dst_slot_idx = (cached_channel_tail_idx++) % num_recv_buffer_tokens;
// Selected tokens are distributed over the warps of the rank group by the
// (already incremented) tail value
if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) {
// Copy data
auto shifted_channel_x_buffers =
channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_channel_x_buffers,
shifted_x, __ldg, st_na_global);
// Copy source index
if (lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
// Copy `topk_idx` and `topk_weights` with transformed index
if (lane_id < num_topk) {
// Top-k index
// Expert ids owned by the destination rank are rebased to start at 0,
// all other ids become -1
int recv_expert_begin = responsible_rank * num_experts_per_rank,
recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end)
? idx_value - recv_expert_begin
: -1;
channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
// Top-k weights
// Weights of experts that are not on the destination rank are zeroed
auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] =
weight_value;
}
// Copy `x_scales`
#pragma unroll
for (int i = lane_id; i < num_scales; i += kWarpSize) {
auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
channel_x_scales_buffers[dst_slot_idx * num_scales + i] =
__ldg(x_scales + offset);
}
}
// Move token index
chunk_token_idx++, token_idx++;
}
// Move tail index
// NOTES: here all warps should share the same new tail
// NOTE(review): `__syncthreads()` is a block-wide barrier, while the number of outer-loop
// iterations depends on `is_token_in_rank` for this group's `responsible_rank`; confirm
// that groups serving different ranks cannot get out of step here.
if (num_threads_per_rank > kWarpSize) {
__syncthreads();
} else {
syncwarp();
}
if (send_warp_id_in_rank == 0 and lane_id == 0)
st_relaxed_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
}
} else {
// Workers for receiving and copying into buffer
constexpr int num_recv_warps = kNumThreads / kWarpSize;
constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
const auto recv_thread_id = thread_id;
const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
const auto recv_warp_id_in_rank = recv_thread_id_in_rank / kWarpSize;
EP_DEVICE_ASSERT(kNumRanks <= kWarpSize);
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
// Calculate offset first
// `rank_prefix_matrix` sits at the very start of this rank's buffer (see layout above)
auto rank_prefix_matrix = static_cast<int *>(buffer_ptrs[rank]);
int rank_offset = responsible_rank > 0
? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank]
: 0;
// Receive channel offset
// Lane 0 spins until the sender has published non-zero (negated) start and end offsets
int total_offset, num_tokens_to_recv;
while (lane_id == 0 and
(total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0)
;
while (lane_id == 0 and
(num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0)
;
if (lane_id == 0) {
// Undo the `-value - 1` encoding used by the sender
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
if (recv_warp_id_in_rank == 0)
recv_channel_offset[responsible_rank * num_channels + responsible_channel] =
total_offset;
num_tokens_to_recv -= total_offset;
}
// Broadcast lane 0's values to the whole warp
total_offset = shfl_sync(total_offset, 0);
total_offset += rank_offset;
num_tokens_to_recv = shfl_sync(num_tokens_to_recv, 0);
// Shared tail indices for different warps
__shared__ volatile int shared_channel_tail_idx[kNumRanks];
auto start_time = clock64();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by
// different warps are the same
// Only the first thread of each rank group polls the tail and publishes it via shared memory
while (recv_thread_id_in_rank == 0) {
cached_channel_tail_idx = ld_relaxed_sys_global(channel_tail_idx.buffer());
// Ready to copy
if (cached_channel_head_idx != cached_channel_tail_idx) {
shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx;
break;
}
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = "
"%d, tokens remained: %d\n",
rank, responsible_channel, num_tokens_to_recv);
trap();
}
}
// Synchronize queue tail
// NOTE(review): block-wide barrier inside a loop whose trip count is per rank group;
// confirm that all groups execute it the same number of times.
if (num_threads_per_rank > kWarpSize) {
__syncthreads();
} else {
syncwarp();
}
cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank];
// Copy data
// One warp per token, warps of the rank group stride over the available tokens
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens;
chunk_idx += num_recv_warps_per_rank) {
int token_idx_in_buffer =
(cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto shifted_buffer_x_int4 =
channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 =
recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_recv_x_int4,
shifted_buffer_x_int4, ld_nc_global, st_na_global);
}
// Copy `src_idx`
// One thread per token, all threads of the rank group stride over the available tokens
#pragma unroll 4
for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank;
chunk_idx < cached_channel_tail_idx;
chunk_idx += kWarpSize * num_recv_warps_per_rank)
recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(
channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
// Copy `topk_idx` and `topk_weights`
// One thread per (token, top-k entry) pair
#pragma unroll 4
for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk;
idx += kWarpSize * num_recv_warps_per_rank) {
int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
int token_idx_in_buffer =
(cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto recv_idx =
static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
recv_topk_idx[recv_idx] =
ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
recv_topk_weights[recv_idx] =
ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
}
// Copy `x_scales`
// One thread per (token, scale) pair
#pragma unroll 4
for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales;
i += kWarpSize * num_recv_warps_per_rank) {
int chunk_idx = i / num_scales, scales_idx = i % num_scales;
int token_idx_in_buffer =
(cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales +
scales_idx] =
ld_nc_global(channel_x_scales_buffers.buffer() +
token_idx_in_buffer * num_scales + scales_idx);
}
// Move queue
cached_channel_head_idx += num_recv_tokens;
total_offset += num_recv_tokens;
// All copies of this round must be finished before the head is released to the sender
if (num_threads_per_rank > kWarpSize) {
__syncthreads();
} else {
syncwarp();
}
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
// Exit
num_tokens_to_recv -= num_recv_tokens;
}
}
// Clean unused `recv_topk_idx` as -1
// All blocks (senders and receivers) take part; threads stride over the tail region
// `[num_recv_tokens * num_topk, num_worst_tokens * num_topk)`
if (num_worst_tokens > 0) {
auto rank_prefix_matrix = static_cast<int *>(buffer_ptrs[rank]);
const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
const auto clean_end = num_worst_tokens * num_topk;
const auto clean_stride = num_sms * kNumThreads;
#pragma unroll
for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
recv_topk_idx[i] = -1;
}
}
// Host-side launcher for the intranode `dispatch` kernel.
//
// Validates the launch preconditions, builds the launch configuration (`num_sms` blocks of
// 1024 threads on `stream`) and instantiates the kernel for the runtime `num_ranks`
// (2, 4 or 8 via SWITCH_RANKS). `recv_x` and `x` are reinterpreted as `int4` arrays; all other
// arguments are forwarded to the kernel unchanged.
void dispatch(void *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
              float *recv_topk_weights, int *recv_channel_offset, int *send_head, const void *x,
              const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
              const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
              int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
              int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
              int num_recv_buffer_tokens) {
    constexpr int kThreadsPerBlock = 1024;

    // The largest per-token scale offset must stay representable as `int`
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride <
                   std::numeric_limits<int>::max());
    // Blocks are paired: even block ids send, odd block ids receive
    EP_HOST_ASSERT(num_sms % 2 == 0);

    SETUP_LAUNCH_CONFIG(num_sms, kThreadsPerBlock, stream);

// Launches the kernel instantiation for a compile-time rank count and leaves the switch
#define LAUNCH_DISPATCH_FOR_RANKS(n_ranks)                                                         \
    {                                                                                              \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch<n_ranks, kThreadsPerBlock>, static_cast<int4 *>(recv_x),                \
            recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset,    \
            send_head, static_cast<const int4 *>(x), x_scales, topk_idx, topk_weights,             \
            is_token_in_rank, channel_prefix_matrix, num_tokens, num_worst_tokens, hidden_int4,    \
            num_topk, num_experts, num_scales, scale_token_stride, scale_hidden_stride,            \
            buffer_ptrs, rank, num_max_send_tokens, num_recv_buffer_tokens);                       \
    }                                                                                              \
    break

    SWITCH_RANKS(LAUNCH_DISPATCH_FOR_RANKS);
#undef LAUNCH_DISPATCH_FOR_RANKS
}
// Preparation kernel for a combine that reuses a cached dispatch layout.
//
// Grid layout: block 0 cleans the local buffer, block `c + 1` post-processes `send_head` for
// channel `c`. Inside a channel block, warp `w` (threads `w * kWarpSize ...`) handles rank `w`;
// warps with `w >= kNumRanks` exit immediately.
//
//  - `buffer_ptrs[rank]`: the first `num_memset_int` ints are set to zero by block 0, bracketed
//    by two `barrier_block` calls on `barrier_signal_ptrs`.
//  - `send_head`: indexed as `[token * kNumRanks + rank_id]`; negative entries are rewritten in
//    place (see below), non-negative entries are left untouched.
template <int kNumRanks>
__global__ void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
int num_recv_tokens, int num_memset_int,
int **barrier_signal_ptrs, int rank) {
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = static_cast<int *>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
// Barrier after cleaning
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
const auto channel_id = sm_id - 1;
const auto thread_id = static_cast<int>(threadIdx.x);
// One warp per rank
const auto rank_id = thread_id / kWarpSize;
const auto lane_id = thread_id % kWarpSize;
if (rank_id >= kNumRanks)
return;
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx,
token_end_idx);
// NOTES: `1 << 25` is a heuristic large number
// `last_head` tracks the most recent non-negative head seen while walking tokens backwards
int last_head = 1 << 25;
// Walk the channel's tokens from the last one to the first one, `kWarpSize` tokens per step;
// lane `i` holds token `token_idx_tail - i`
#pragma unroll
for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx;
token_idx_tail -= kWarpSize) {
int token_idx = token_idx_tail - lane_id, expected_head = 0;
auto current_head = (token_idx >= token_start_idx)
? __ldg(send_head + token_idx * kNumRanks + rank_id)
: -1;
// Visit the lanes in ascending order, i.e. tokens in descending order, so every lane ends
// up with the same `last_head` after the loop
for (int i = 0; i < min(kWarpSize, token_idx_tail - token_start_idx + 1); ++i) {
const int head = shfl_sync(current_head, i);
if (head < 0) {
// Encode the head of the next later token that has one as `-head - 1`
if (lane_id == i)
expected_head = -last_head - 1;
} else {
last_head = head;
}
}
// Only negative entries of in-range tokens are overwritten
if (current_head < 0 and token_idx >= token_start_idx)
send_head[token_idx * kNumRanks + rank_id] = expected_head;
}
}
}
// Host-side launcher for the `cached_notify_combine` kernel.
//
// Launches `1 + num_channels` blocks (one cleaning block plus one block per channel) with
// `max(128, kWarpSize * num_ranks)` threads each, instantiated for the runtime `num_ranks`.
void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
                           int num_recv_tokens, int num_memset_int, int **barrier_signal_ptrs,
                           int rank, int num_ranks, hipStream_t stream) {
    // The kernel assigns one warp per rank, so at least `kWarpSize * num_ranks` threads are needed
    const int threads_per_block = ::max(128, kWarpSize * num_ranks);
    EP_HOST_ASSERT(num_ranks <= threads_per_block);
    EP_HOST_ASSERT(threads_per_block <= 1024);
    EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);

    SETUP_LAUNCH_CONFIG(1 + num_channels, threads_per_block, stream);

// Launches the kernel instantiation for a compile-time rank count and leaves the switch
#define LAUNCH_CACHED_NOTIFY_COMBINE_FOR_RANKS(n_ranks)                                            \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_combine<n_ranks>, buffer_ptrs, send_head,    \
                                  num_channels, num_recv_tokens, num_memset_int,                   \
                                  barrier_signal_ptrs, rank);                                      \
    break

    SWITCH_RANKS(LAUNCH_CACHED_NOTIFY_COMBINE_FOR_RANKS);
#undef LAUNCH_CACHED_NOTIFY_COMBINE_FOR_RANKS
}
// Intranode combine kernel: sends the locally held token data back to the ranks the tokens
// came from and reduces (sums) the contributions of all ranks per token.
//
// Launch layout (as used by the code below):
//  - `gridDim.x / 2` channels; block `2 * c` is the sender and block `2 * c + 1` the receiver of
//    channel `c`.
//  - Sender blocks: warp `w` serves rank `(c + w) % kNumRanks`; `kNumThreads` must be an exact
//    multiple of `kNumRanks * kWarpSize` (static-asserted).
//  - Receiver blocks: warp 0 maintains the queue heads/tails, all other warps reduce tokens.
//    Static shared memory holds the per-warp heads, the per-rank tails and the retired flags.
//
// Selected parameters (only what the code below shows):
//  - `x` / `recv_x`: `hidden` elements of `dtype_t` per token, accessed as `int4` vectors.
//  - `bias_0` / `bias_1`: optional (may be null); added to every reduced token.
//  - `send_head[token * kNumRanks + r]`: ring-buffer position of the token's contribution from
//    rank `r`; negative values mean "no contribution" and encode the next head as `-head - 1`.
//  - `rank_prefix_matrix` / `channel_prefix_matrix`: prefix sums used by senders to find the
//    token range belonging to a (rank, channel) pair.
template <typename dtype_t, int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t *recv_x, float *recv_topk_weights, const dtype_t *x, const float *topk_weights,
        const dtype_t *bias_0, const dtype_t *bias_1, const int *src_idx,
        const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
        int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
        int rank, int num_max_send_tokens, int num_recv_buffer_tokens) {
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
    const auto num_channels = num_sms / 2;
    const bool is_sender = sm_id % 2 == 0;
    const int responsible_channel = sm_id / 2;
    // One lane per top-k entry is used when copying / reducing `topk_weights`
    EP_DEVICE_ASSERT(num_topk <= kWarpSize);
    constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
    int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
    auto x_int4 = reinterpret_cast<const int4 *>(x);
    auto bias_0_int4 = reinterpret_cast<const int4 *>(bias_0);
    auto bias_1_int4 = reinterpret_cast<const int4 *>(bias_1);
    auto recv_int4 = reinterpret_cast<int4 *>(recv_x);
    if (is_sender) {
        // Workers for sending
        // Several warps are responsible for a single rank
        constexpr int num_send_warps_per_rank = (kNumThreads / kWarpSize) / kNumRanks;
        constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
        const auto num_threads_per_rank = num_send_warps_per_rank * kWarpSize;
        const auto send_thread_id = thread_id;
        const auto send_warp_id = send_thread_id / kWarpSize;
        const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks;
        const auto send_warp_id_in_rank = send_warp_id / kNumRanks;
        EP_STATIC_ASSERT(num_send_warps * kWarpSize == kNumThreads, "Invalid warp count");
        // Calculate pointers by the specific layout
        auto ptr = reinterpret_cast<void *>(static_cast<int8_t *>(buffer_ptrs[send_rank_id]));
        auto num_channels_total = num_channels * kNumRanks;
        auto channel_rank_offset = responsible_channel * kNumRanks + rank;
        // Channel meta data
        // `head_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 *
        // sizeof(int4)
        // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
        // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
        // num_topk * sizeof(float)
        auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_x_buffers =
            Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
                         channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
        auto channel_src_idx_buffers =
            Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens,
                        channel_rank_offset * num_recv_buffer_tokens);
        auto channel_topk_weights_buffers =
            Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                          channel_rank_offset * num_recv_buffer_tokens * num_topk);
        // Get tasks
        // NOTES: `channel_offset` is already shifted
        int rank_offset =
            send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
        int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
        int channel_offset =
            channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
        int num_channel_tokens =
            (responsible_channel == num_channels - 1
                 ? num_rank_tokens
                 : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) -
            channel_offset;
        int token_start_idx = rank_offset + channel_offset,
            token_end_idx = rank_offset + channel_offset + num_channel_tokens;
        // Iterate over all tokens and send by chunks
        int current_channel_tail_idx = 0;
        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
            // Check destination queue emptiness, or wait a buffer to be released (rare cases)
            auto start_time = wall_clock64();
            int num_round_tokens =
                min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
            while (lane_id == 0) {
                // NOTES: we only consider the worst case, because counting the real numbers are
                // time-consuming
                int num_used_slots =
                    current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
                    break;
                // Rare cases to loop again
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf(
                        "DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n",
                        rank, responsible_channel);
                    trap();
                }
            }
            syncwarp();
            // Send by chunk
            // The warps of a rank group take the tokens of this round in a strided fashion
#pragma unroll
            for (int i = send_warp_id_in_rank; i < num_round_tokens;
                 i += num_send_warps_per_rank) {
                // Get an empty slot
                int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
                // Copy data
                auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
                UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_x_buffers, shifted_x,
                                   ld_nc_global, st_na_global);
                // Send source index
                if (lane_id == 0)
                    channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
                // Send `topk_weights`
                if (num_topk > 0 and lane_id < num_topk)
                    channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] =
                        __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
            }
            token_idx += num_round_tokens;
            current_channel_tail_idx += num_round_tokens;
            // Move tail index
            // NOTE(review): `__syncthreads()` is a block-wide barrier, while the number of
            // rounds depends on the per-rank token count; confirm that warps serving different
            // ranks cannot get out of step here.
            if (num_threads_per_rank > kWarpSize) {
                __syncthreads();
            } else {
                syncwarp();
            }
            if (lane_id == 0 and send_warp_id_in_rank == 0)
                st_relaxed_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
        }
    } else {
        // Workers for receiving
        // One warp for moving the queue head, others for reduction
        constexpr int num_recv_warps = kNumThreads / kWarpSize;
        const auto recv_warp_id = thread_id / kWarpSize;
        EP_DEVICE_ASSERT(kNumRanks <= kWarpSize and kNumThreads > kWarpSize);
        EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % kWarpSize == 0);
        // Shared head, tail and retired flags for receiver warps
        __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
        __shared__ volatile int channel_tail_idx[kNumRanks];
        __shared__ volatile bool warp_retired[num_recv_warps];
        if (thread_id < num_recv_warps)
            warp_retired[thread_id] = false;
        if (lane_id < kNumRanks)
            warp_channel_head_idx[recv_warp_id][lane_id] = 0;
        if (thread_id < kNumRanks)
            channel_tail_idx[thread_id] = 0;
        __syncthreads();
        if (thread_id < kWarpSize) {
            // Lane `r` of warp 0 owns the head/tail pair of rank `r`
            int *channel_head_idx_ptr =
                static_cast<int *>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
            int *channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
            // Queue head updater
            int last_head = 0;
            while (lane_id < kNumRanks) {
                // Check retired
                // Stop once every reduction warp (warps 1..num_recv_warps-1) has finished
                bool retired = true;
#pragma unroll
                for (int i = 1; i < num_recv_warps; ++i)
                    retired = retired and warp_retired[i];
                if (retired)
                    break;
                // Update queue tail
                channel_tail_idx[lane_id] = ld_relaxed_sys_global(channel_tail_idx_ptr);
                // Update minimum head
                // A slot may only be released once all active reduction warps are past it
                int min_head = std::numeric_limits<int>::max();
#pragma unroll
                for (int i = 1; i < num_recv_warps; ++i)
                    if (not warp_retired[i])
                        min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
                if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
                    st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
            }
        } else {
            // Receivers
            // Channel metadata
            // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`
            Buffer<int4> channel_x_buffers[kNumRanks];
            Buffer<float> channel_topk_weights_buffers[kNumRanks];
            // Calculate pointers by the specific layout
#pragma unroll
            for (int i = 0; i < kNumRanks; ++i) {
                auto channel_rank_offset = responsible_channel * kNumRanks + i;
                auto num_channels_total = num_channels * kNumRanks;
                // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
                auto ptr = reinterpret_cast<void *>(static_cast<int8_t *>(buffer_ptrs[rank]) +
                                                    2 * num_channels * kNumRanks * sizeof(int));
                // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 *
                // sizeof(int4)
                channel_x_buffers[i] =
                    Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
                                 channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
                // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
                // sizeof(int)
                // (skipped by hand, the receiver does not read the source indices)
                ptr = reinterpret_cast<void *>(static_cast<int8_t *>(ptr) +
                                               num_channels_total * num_recv_buffer_tokens *
                                                   sizeof(int));
                // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
                // num_topk * sizeof(float)
                channel_topk_weights_buffers[i] =
                    Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                                  channel_rank_offset * num_recv_buffer_tokens * num_topk);
            }
            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_recv_tokens, num_channels, responsible_channel,
                                   token_start_idx, token_end_idx);
            // Iterate over all tokens and combine
            // Reduction warps are numbered from 1, hence the `- 1` in start and stride
            for (int64_t token_idx = token_start_idx + recv_warp_id - 1;
                 token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
                // Read expected head
                int expected_head = -1;
                if (lane_id < kNumRanks)
                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
                auto start_time = wall_clock64();
                // Wait until every contributing rank has pushed this token.
                // `expected_head >= 0` is tested first: lanes with `lane_id >= kNumRanks` always
                // hold -1 and must not index the `kNumRanks`-element `channel_tail_idx` array.
                while (__any(expected_head >= 0 and channel_tail_idx[lane_id] <= expected_head)) {
                    // Timeout check
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP timeout for combine receivers, rank %d, responsible_channel "
                               "= %d, expect = %d\n",
                               rank, responsible_channel, expected_head);
                        trap();
                    }
                }
                syncwarp();
                // Broadcast current heads
                // Collect the ranks that contribute to this token and their buffer slots
                int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
#pragma unroll
                for (int i = 0; i < kNumRanks; ++i) {
                    auto expected_head_i = __shfl(expected_head, i);
                    if (expected_head_i >= 0) {
                        slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
                        topk_ranks[num_topk_ranks++] = i;
                    }
                }
                // Reduce data
#pragma unroll
                for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
                    // Read bias
                    // TODO: make it as a template
                    int4 bias_0_value_int4 = bias_0_int4 != nullptr
                                                 ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i)
                                                 : make_int4(0, 0, 0, 0);
                    int4 bias_1_value_int4 = bias_1_int4 != nullptr
                                                 ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i)
                                                 : make_int4(0, 0, 0, 0);
                    // Accumulate in float, starting from the sum of both biases
                    float values[kDtypePerInt4] = {0};
                    auto bias_0_values = reinterpret_cast<const dtype_t *>(&bias_0_value_int4);
                    auto bias_1_values = reinterpret_cast<const dtype_t *>(&bias_1_value_int4);
#pragma unroll
                    for (int j = 0; j < kDtypePerInt4; ++j)
                        values[j] = static_cast<float>(bias_0_values[j]) +
                                    static_cast<float>(bias_1_values[j]);
#pragma unroll
                    for (int j = 0; j < num_topk_ranks; ++j) {
                        int4 recv_value = __ldg(channel_x_buffers[topk_ranks[j]].buffer() +
                                                slot_indices[j] * hidden_int4 + i);
                        const dtype_t *recv_dtypes =
                            reinterpret_cast<const dtype_t *>(&recv_value);
#pragma unroll
                        for (int k = 0; k < kDtypePerInt4; ++k)
                            values[k] += static_cast<float>(recv_dtypes[k]);
                    }
                    // Cast back to `dtype_t`
                    int4 out_int4;
                    auto out_dtypes = reinterpret_cast<dtype_t *>(&out_int4);
#pragma unroll
                    for (int j = 0; j < kDtypePerInt4; ++j)
                        out_dtypes[j] = static_cast<dtype_t>(values[j]);
                    recv_int4[token_idx * hidden_int4 + i] = out_int4;
                }
                // Reduce `topk_weights`
                if (lane_id < num_topk) {
                    float value = 0;
#pragma unroll
                    for (int i = 0; i < num_topk_ranks; ++i)
                        value +=
                            ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() +
                                         slot_indices[i] * num_topk + lane_id);
                    recv_topk_weights[token_idx * num_topk + lane_id] = value;
                }
                // Update head
                // Negative entries carry the next head encoded as `-head - 1`
                if (lane_id < kNumRanks)
                    warp_channel_head_idx[recv_warp_id][lane_id] =
                        (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
            }
            // Retired
            syncwarp();
            if (lane_id == 0)
                warp_retired[recv_warp_id] = true;
        }
    }
}
// Host-side launcher for the intranode `combine` kernel.
//
// Validates the launch preconditions, builds the launch configuration (`num_sms` blocks of
// 1024 threads on `stream`) and instantiates the kernel for the runtime element type `type`
// (via SWITCH_TYPES) and the runtime `num_ranks` (via SWITCH_RANKS_WITH_DTYPE). The untyped
// `recv_x`, `x`, `bias_0` and `bias_1` pointers are cast to the selected element type.
void combine(hipDataType type, void *recv_x, float *recv_topk_weights, const void *x,
             const float *topk_weights, const void *bias_0, const void *bias_1, const int *src_idx,
             const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
             int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
             int rank, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
             int num_recv_buffer_tokens) {
    constexpr int kThreadsPerBlock = 1024;

    // Blocks are paired: even block ids send, odd block ids receive
    EP_HOST_ASSERT(num_sms % 2 == 0);
    // Every rank needs at least one full warp inside a block
    EP_HOST_ASSERT(kThreadsPerBlock >= num_ranks * kWarpSize);

    SETUP_LAUNCH_CONFIG(num_sms, kThreadsPerBlock, stream);

// Launches the kernel instantiation for one (element type, rank count) pair and leaves the
// rank switch
#define LAUNCH_COMBINE_FOR(elem_t, n_ranks)                                                        \
    {                                                                                              \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine<elem_t, n_ranks, kThreadsPerBlock>, static_cast<elem_t *>(recv_x),       \
            recv_topk_weights, static_cast<const elem_t *>(x), topk_weights,                       \
            static_cast<const elem_t *>(bias_0), static_cast<const elem_t *>(bias_1), src_idx,     \
            rank_prefix_matrix, channel_prefix_matrix, send_head, num_tokens, num_recv_tokens,     \
            hidden, num_topk, buffer_ptrs, rank, num_max_send_tokens, num_recv_buffer_tokens);     \
    }                                                                                              \
    break

// Selects the rank count for one element type and leaves the type switch
#define LAUNCH_COMBINE_FOR_TYPE(elem_t)                                                            \
    SWITCH_RANKS_WITH_DTYPE(elem_t, LAUNCH_COMBINE_FOR);                                           \
    break

    SWITCH_TYPES(LAUNCH_COMBINE_FOR_TYPE);
#undef LAUNCH_COMBINE_FOR_TYPE
#undef LAUNCH_COMBINE_FOR
}
} // namespace intranode
} // namespace deep_ep
#include "hip/hip_runtime.h"
#pragma once #pragma once
#include "configs.cuh" #include "configs.cuh"
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#pragma once
#include "configs.cuh"
#include "exception.cuh"
// ROCm helper functions and structures
namespace rocm::experimental {

// Minimal launch descriptor consumed by the launch helpers in this header: grid size,
// block size, dynamic shared-memory size in bytes and the stream to launch on.
typedef struct {
    dim3 num_sms;
    dim3 num_threads;
    unsigned int shared_mem_bytes;
    hipStream_t stream;
} hipLaunchConfig_t;

// Stores the address of `arg` in `f[idx]` (terminal case of the variadic overload below).
//
// The address is routed through `const void *` so that const-qualified lvalues are accepted
// too; a plain `static_cast<void *>` would reject them at compile time. The stored pointers
// are only valid while the referenced arguments are alive.
template <typename T> void fill_kernel_args(void **f, size_t idx, T &&arg) {
    f[idx] = const_cast<void *>(static_cast<const void *>(std::addressof(arg)));
}

// Fills `f[idx]`, `f[idx + 1]`, ... with the addresses of `head`, `tail...` in order, producing
// the `void **` argument array expected by pointer-array style kernel launch APIs.
// `f` must have room for `idx + 1 + sizeof...(tail)` entries.
template <typename Head, typename... Tail>
void fill_kernel_args(void **f, size_t idx, Head &&head, Tail &&...tail) {
    f[idx] = const_cast<void *>(static_cast<const void *>(std::addressof(head)));
    fill_kernel_args(f, idx + 1, std::forward<Tail>(tail)...);
}

} // namespace rocm::experimental
#ifndef SETUP_LAUNCH_CONFIG
// The code below is a workaround for ROCm. All the proposed overhead
// is to match current macro signatures and should be reworked once
// a HIP alternative to cudaLaunchKernelEx() is live.
//
// Declares a local variable named `cfg` of type `rocm::experimental::hipLaunchConfig_t` with
// `num_sms` blocks, `num_threads` threads per block, 0 bytes of dynamic shared memory and the
// given `stream`. The launch helpers in this header are called with `&cfg`.
// NOTE: the expansion already ends in `;`, so the usual `SETUP_LAUNCH_CONFIG(...);` leaves an
// extra empty statement behind.
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
rocm::experimental::hipLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream};
#endif // #ifndef SETUP_LAUNCH_CONFIG
#ifndef LAUNCH_KERNEL
// Launches `kernel` cooperatively with the grid/block/shared-memory/stream settings taken from
// `config` (a pointer to a launch descriptor such as the `cfg` created by SETUP_LAUNCH_CONFIG).
//
// The cooperative launch API takes its kernel arguments as an array of pointers, so the
// addresses of `args...` are gathered first. Launch failures are reported through CUDA_CHECK.
template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL(T &&config, Kern &&kernel, Args &&...args) {
    // One pointer slot per kernel argument, filled in declaration order
    void *arg_ptrs[sizeof...(Args)];
    rocm::experimental::fill_kernel_args(arg_ptrs, 0, std::forward<Args>(args)...);

    CUDA_CHECK(hipLaunchCooperativeKernel(std::forward<Kern>(kernel), config->num_sms,
                                          config->num_threads, arg_ptrs,
                                          config->shared_mem_bytes, config->stream));
}
// Launches `kernel` as a regular (non-cooperative) kernel with the grid/block/shared-memory/
// stream settings taken from `config` (a pointer to a launch descriptor such as the `cfg`
// created by SETUP_LAUNCH_CONFIG), forwarding `args...` as the kernel arguments.
//
// `hipLaunchKernelGGL` itself reports nothing, so the launch status is fetched with
// `hipGetLastError()` and passed to CUDA_CHECK, matching the error handling of LAUNCH_KERNEL.
// This catches launch-time failures (e.g. an invalid configuration); faults that happen while
// the kernel runs still surface at a later synchronizing call.
template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...args) {
    hipLaunchKernelGGL((*kernel), dim3(config->num_sms), dim3(config->num_threads),
                       config->shared_mem_bytes, config->stream, std::forward<Args>(args)...);
    CUDA_CHECK(hipGetLastError());
}
#endif // #ifndef LAUNCH_KERNEL
// Turns the runtime value `num_ranks` (must be in scope at the expansion site) into a
// compile-time constant: expands `case_macro(N)` for N in {2, 4, 8}.
// `case_macro` is expected to end in `break` (as the launch-case macros of the callers do);
// otherwise control falls through into the next case. Any other value trips EP_HOST_ASSERT.
// The trailing `while (false)` consumes the `;` written after the macro invocation.
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: \
case_macro(2); \
case 4: \
case_macro(4); \
case 8: \
case_macro(8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
// Same dispatch pattern as the other SWITCH_* macros, keyed on `num_ranks / NUM_MAX_NVL_PEERS`
// (both names must be in scope at the expansion site): expands `case_macro(N)` for
// N in {2, 3, 4, 8, 16, 18, 20}. `case_macro` is expected to end in `break`; any other value
// trips EP_HOST_ASSERT. The trailing `while (false)` consumes the caller's `;`.
#define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: \
case_macro(2); \
case 3: \
case_macro(3); \
case 4: \
case_macro(4); \
case 8: \
case_macro(8); \
case 16: \
case_macro(16); \
case 18: \
case_macro(18); \
case 20: \
case_macro(20); \
default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} \
while (false)
// Variant of SWITCH_RANKS that also passes an element type through: expands
// `case_macro(dtype, N)` for the runtime `num_ranks` (must be in scope) with N in {2, 4, 8}.
// `case_macro` is expected to end in `break`; any other value trips EP_HOST_ASSERT.
// The trailing `while (false)` consumes the caller's `;`.
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
switch (num_ranks) { \
case 2: \
case_macro(dtype, 2); \
case 4: \
case_macro(dtype, 4); \
case 8: \
case_macro(dtype, 8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
// Dispatches on a variable named `type`, which must be in scope at the expansion site:
// HIP_R_16BF expands `case_macro(hip_bfloat16)` and HIP_R_32F expands `case_macro(float)`.
// No `break` is emitted here, so `case_macro` must leave the switch itself or control
// falls through into the next case. Any other value reaches EP_HOST_ASSERT(false ...).
// The trailing `while (false)` consumes the `;` written after the macro invocation.
#define SWITCH_TYPES(case_macro) \
switch (type) { \
case HIP_R_16BF: \
case_macro(hip_bfloat16); \
case HIP_R_32F: \
case_macro(float); \
default: \
EP_HOST_ASSERT(false and "Unsupported type"); \
} \
while (false)
// Dispatches on a variable named `hidden`, which must be in scope at the expansion site,
// expanding `case_macro(N)` for N = 2560, 5120, 4096, 7168. No `break` is emitted here,
// so `case_macro` must leave the switch itself or control falls through into the next
// case. Any other value reaches EP_HOST_ASSERT(false ...). The trailing `while (false)`
// consumes the `;` written after the macro invocation.
#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2560: \
case_macro(2560); \
case 5120: \
case_macro(5120); \
case 4096: \
case_macro(4096); \
case 7168: \
case_macro(7168); \
default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
} \
while (false)
#include "hip/hip_runtime.h"
#include <cstring> #include <cstring>
#include "configs.cuh" #include "configs.cuh"
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include <cstring>
#include "configs.cuh"
#include "exception.cuh"
#include "launch_hip.cuh"
#include "utils_hip.cuh"
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#endif
namespace deep_ep {
namespace intranode {
// Kernel entry point that simply runs barrier_block<kNumRanks> with the given signal
// pointer table and the caller's rank. barrier_block asserts kNumRanks <= blockDim.x, so
// the launch must provide at least kNumRanks threads per block.
template <int kNumRanks>
__global__ void barrier(int **barrier_signal_ptrs, int rank) {
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
// Host wrapper that launches the barrier<ranks> kernel on `stream`.
// The launch configuration is built with num_sms = 1 and num_threads = kWarpSize, and the
// kernel template argument is selected from `num_ranks` by SWITCH_RANKS (2, 4 or 8; any
// other value hits the host assertion inside that macro).
void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream) {
// Per-case body for SWITCH_RANKS: launch through the `cfg` variable declared by
// SETUP_LAUNCH_CONFIG below, then `break` out of the switch.
#define BARRIER_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \
break
SETUP_LAUNCH_CONFIG(1, kWarpSize, stream);
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
// Keep the helper macro local to this function.
#undef BARRIER_LAUNCH_CASE
}
} // namespace intranode
namespace internode {
#ifndef DISABLE_ROCSHMEM
// Team handle filled in by init() when low-latency mode is requested with more than
// NUM_MAX_NVL_PEERS ranks; it stays ROCSHMEM_TEAM_INVALID otherwise and is reset to
// ROCSHMEM_TEAM_INVALID by finalize().
rocshmem::rocshmem_team_t cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID;
// Team configuration object whose address init() passes to rocshmem_team_split_strided().
// It is never assigned in the code shown here.
rocshmem::rocshmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() {
rocshmem::rocshmem_uniqueid_t unique_id;
rocshmem::rocshmem_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(rocshmem::rocshmem_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(rocshmem::rocshmem_uniqueid_t));
return result;
}
// Initializes rocSHMEM for this process from a serialized unique id.
//
// root_unique_id_val: raw bytes of a rocshmem_uniqueid_t (as produced by get_unique_id());
//                     must hold at least sizeof(rocshmem_uniqueid_t) bytes.
// rank, num_ranks:    forwarded to rocshmem_set_attr_uniqueid_args().
// low_latency_mode:   when set and num_ranks > NUM_MAX_NVL_PEERS, a strided sub-team is
//                     created and stored in `cpu_rdma_team`.
//
// Ends with rocshmem_barrier_all() and returns rocshmem_my_pe().
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks,
         bool low_latency_mode) {
    // The memcpy below reads sizeof(rocshmem_uniqueid_t) bytes from the vector; a shorter
    // buffer would be read out of bounds, so reject it up front.
    EP_HOST_ASSERT(root_unique_id_val.size() >= sizeof(rocshmem::rocshmem_uniqueid_t));

    rocshmem::rocshmem_uniqueid_t root_unique_id;
    rocshmem::rocshmem_init_attr_t attr;
    std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(rocshmem::rocshmem_uniqueid_t));
    rocshmem::rocshmem_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
    rocshmem::rocshmem_init_attr(rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID, &attr);

    // Create sub-RDMA teams
    // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
    if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
        // The team must not have been created already, and the ranks must split evenly.
        EP_HOST_ASSERT(cpu_rdma_team == rocshmem::ROCSHMEM_TEAM_INVALID);
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        // Split arguments: start = rank % NUM_MAX_NVL_PEERS, stride = NUM_MAX_NVL_PEERS,
        // size = num_ranks / NUM_MAX_NVL_PEERS.
        EP_HOST_ASSERT(rocshmem::rocshmem_team_split_strided(
                           rocshmem::ROCSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS,
                           NUM_MAX_NVL_PEERS, num_ranks / NUM_MAX_NVL_PEERS,
                           &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
        EP_HOST_ASSERT(cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID);
    }

    rocshmem::rocshmem_barrier_all();
    return rocshmem::rocshmem_my_pe();
}
// Allocates from the rocSHMEM heap via rocshmem_malloc().
// The requested `size` is first passed through the ALIGN(size, alignment) macro (defined
// elsewhere); `alignment` is used only for that size computation in this function.
void *alloc(size_t size, size_t alignment) {
    return rocshmem::rocshmem_malloc(ALIGN(size, alignment));
}
// Releases `ptr` through rocshmem_free(); counterpart of alloc() in this namespace.
void free(void *ptr) {
rocshmem::rocshmem_free(ptr);
}
// Host-side barrier: forwards to rocshmem_barrier_all().
void barrier() {
rocshmem::rocshmem_barrier_all();
}
void finalize() {
if (cpu_rdma_team != rocshmem::ROCSHMEM_TEAM_INVALID) {
rocshmem::rocshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = rocshmem::ROCSHMEM_TEAM_INVALID;
}
rocshmem::rocshmem_finalize();
}
#endif
} // namespace internode
} // namespace deep_ep
#include "hip/hip_runtime.h"
#pragma once #pragma once
#include "configs.cuh" #include "configs.cuh"
#include "exception.cuh" #include "exception.cuh"
...@@ -194,8 +195,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const volatile uint64_t *p ...@@ -194,8 +195,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const volatile uint64_t *p
return ret; return ret;
} }
template <typename dtype_t> template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
using T = typename VecInt<sizeof(dtype_t)>::vec_t; using T = typename VecInt<sizeof(dtype_t)>::vec_t;
auto ret = __builtin_nontemporal_load(reinterpret_cast<const T *>(ptr)); auto ret = __builtin_nontemporal_load(reinterpret_cast<const T *>(ptr));
return *reinterpret_cast<dtype_t *>(&ret); return *reinterpret_cast<dtype_t *>(&ret);
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#pragma once
#include "configs.cuh"
#include "exception.cuh"
// Copies N elements from SRC to DST, with each participating thread starting at index
// LANE_ID and striding by kWarpSize (LANE_ID is presumably the caller's lane index within
// its warp -- confirm at call sites).
//
// Main loop: processes whole chunks of kWarpSize * UNROLL_FACTOR elements; all
// UNROLL_FACTOR loads via LD_FUNC are issued first into `unrolled_values`, then all
// stores via ST_FUNC.
// Tail block: handles the remaining N % (kWarpSize * UNROLL_FACTOR) elements with the
// same load-then-store split, guarding each access with an `index < N` check.
//
// SRC and DST are evaluated once (captured into __src / __dst); N and LANE_ID are
// expanded multiple times. The element type of the staging array is deduced from
// LD_FUNC's return type.
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type \
unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) \
unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) \
ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
{ \
int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) { \
if (__i + __j * kWarpSize < (N)) { \
unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
} \
} \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) { \
if (__i + __j * kWarpSize < (N)) { \
ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
} \
} \
}
// Variant of the unrolled warp copy that strides by kEmulatedWarpSize instead of
// kWarpSize. Copies N elements from SRC to DST, each participating thread starting at
// index LANE_ID.
//
// Main loop: processes whole chunks of kEmulatedWarpSize * UNROLL_FACTOR elements, issuing
// all UNROLL_FACTOR loads (LD_FUNC) before the matching stores (ST_FUNC).
// Tail loop: copies the remaining elements one at a time per thread, striding by
// kEmulatedWarpSize until the index reaches N.
//
// SRC and DST are evaluated once (captured into __src / __dst); N and LANE_ID are
// expanded multiple times.
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type \
unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) \
unrolled_values[__j] = LD_FUNC(__src + __i + __j * kEmulatedWarpSize); \
_Pragma("unroll") for (int __j = 0; __j < (UNROLL_FACTOR); ++__j) \
ST_FUNC(__dst + __i + __j * kEmulatedWarpSize, unrolled_values[__j]); \
} \
for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); \
__i += kEmulatedWarpSize) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
// HELPER FUNCTIONS
// #####################################################################################
// XOR shuffle wrapper: returns __shfl_xor(val, laneMask, width).
// `shfl_sync_mask` is accepted for signature compatibility with mask-taking call sites
// but is not used by the body.
template <typename T>
__device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWarpSize,
uint64_t shfl_sync_mask = kFullWarpMask) {
return __shfl_xor(val, laneMask, width);
}
// Indexed shuffle wrapper for int values: returns __shfl(val, srcLane, width).
// `srcLane` defaults to 0 and `width` to kWarpSize. `shfl_sync_mask` is accepted for
// signature compatibility but is not used by the body.
__device__ __forceinline__ int
shfl_sync(const int val, int srcLane = 0, int width = kWarpSize,
uint64_t shfl_sync_mask = kFullWarpMask) { // Let compiler deduce type
return __shfl(val, srcLane, width);
}
// Mask-taking "any" vote built on __ballot(): returns 1 when at least one bit of the
// ballot result selected by `mask` is set, 0 otherwise.
__device__ __forceinline__ int __any_sync(uint64_t mask, int predicate) {
    const uint64_t ballot_bits = __ballot(predicate);
    const uint64_t selected_true = ballot_bits & mask;
    return selected_true != 0;
}
// Mask-taking "all" vote built on __ballot(): returns 1 when every bit selected by `mask`
// is set in the ballot result, 0 otherwise. A mask bit whose ballot bit is clear (for any
// reason) makes the result 0.
__device__ __forceinline__ int __all_sync(uint64_t mask, int predicate) {
    const uint64_t ballot_bits = __ballot(predicate);
    const uint64_t selected_false = mask & ~ballot_bits;
    return selected_false == 0;
}
// Wavefront-level synchronization helper: a release fence at "wavefront" scope, then
// __builtin_amdgcn_wave_barrier(), then an acquire fence at "wavefront" scope.
// The order of the three statements is the behavior; do not reorder.
__device__ __forceinline__ void syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
// ######################################################################################################
namespace deep_ep {
// Maps a byte count to an integer (or integer-vector) type of that size via the nested
// `vec_t` alias. The primary template is intentionally empty, so using an unsupported
// size fails to compile when `vec_t` is looked up.
template <int kBytes>
struct VecInt {};

// 1 byte
template <>
struct VecInt<1> { using vec_t = int8_t; };

// 2 bytes
template <>
struct VecInt<2> { using vec_t = int16_t; };

// 4 bytes
template <>
struct VecInt<4> { using vec_t = int; };

// 8 bytes
template <>
struct VecInt<8> { using vec_t = int64_t; };

// 16 bytes: a compiler-native 4 x int vector (ext_vector_type), also exposed as
// `native_int4`.
template <>
struct VecInt<16> {
    using native_int4 = int __attribute__((ext_vector_type(4)));
    using vec_t = native_int4;
};
// Device-side fatal-error helper: calls abort().
__device__ __forceinline__ void trap() {
abort();
}
// Memory fence wrapper: issues __threadfence_system().
__device__ __forceinline__ void memory_fence() {
__threadfence_system();
}
// Memory fence wrapper: issues __threadfence().
__device__ __forceinline__ void memory_fence_gpu() {
__threadfence();
}
// Memory fence wrapper: issues __threadfence_block().
__device__ __forceinline__ void memory_fence_cta() {
__threadfence_block();
}
// Stores `val` to `*ptr` using __builtin_nontemporal_store().
// NOTE(review): unlike the st_release_* helpers below, this is not a __hip_atomic_store
// with an explicit scope; confirm that a non-temporal store is the intended equivalent of
// a relaxed system-scope store.
__device__ __forceinline__ void st_relaxed_sys_global(int *ptr, int val) {
__builtin_nontemporal_store(val, ptr);
}
// Atomic store of `val` to `*ptr` with __ATOMIC_RELEASE ordering at
// __HIP_MEMORY_SCOPE_SYSTEM. The pointer is taken as `const int *` and the constness is
// cast away before the store.
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
// Atomic store of `val` to `*ptr` with __ATOMIC_RELEASE ordering at
// __HIP_MEMORY_SCOPE_WORKGROUP. The pointer is taken as `const int *` and the constness is
// cast away before the store.
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_WORKGROUP);
}
// Loads an int from `*ptr` using __builtin_nontemporal_load() and returns it.
__device__ __forceinline__ int ld_relaxed_sys_global(const int *ptr) {
    return __builtin_nontemporal_load(ptr);
}
// Atomic load of a 64-bit value from `*ptr` with __ATOMIC_RELAXED ordering at
// __HIP_MEMORY_SCOPE_SYSTEM.
//
// Returns the full 64-bit value. The return type used to be `int`, which silently dropped
// the upper 32 bits of the loaded word (and sign-extended the remainder when the result
// was widened again); it now matches the uint64_t overload of ld_acquire_sys_global.
__device__ __forceinline__ uint64_t ld_relaxed_sys_global(const uint64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}
// Atomic load of an int from `*ptr` with __ATOMIC_ACQUIRE ordering at
// __HIP_MEMORY_SCOPE_SYSTEM.
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM);
}

// 64-bit overload: atomic load of a uint64_t from `*ptr` with __ATOMIC_ACQUIRE ordering at
// __HIP_MEMORY_SCOPE_SYSTEM.
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM);
}
// Atomic load of an int from `*ptr` with __ATOMIC_ACQUIRE ordering at
// __HIP_MEMORY_SCOPE_AGENT.
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
}
// Atomically adds `value` to `*ptr` (constness is cast away) and returns the result of
// atomicAdd().
//
// NOTE(review): the name suggests release ordering, but the active code is a plain
// atomicAdd(); the explicit release/agent-scope variant is kept below, disabled:
//   __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
//                          __HIP_MEMORY_SCOPE_AGENT);
// Confirm whether callers rely on release semantics here.
__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
    int *target = (int*)ptr;
    return atomicAdd(target, value);
}
// Atomic load of an int from `*ptr` with __ATOMIC_ACQUIRE ordering at
// __HIP_MEMORY_SCOPE_WORKGROUP.
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
}
// ld_volatile_global overload set: each overload performs an atomic load from a volatile
// pointer with __ATOMIC_RELAXED ordering at __HIP_MEMORY_SCOPE_SYSTEM.

// int overload.
__device__ __forceinline__ int ld_volatile_global(const volatile int *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// float overload.
__device__ __forceinline__ float ld_volatile_global(const volatile float *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// int64_t overload.
__device__ __forceinline__ int64_t ld_volatile_global(const volatile int64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}

// uint64_t overload; note that the loaded value is returned as a signed int64_t.
__device__ __forceinline__ int64_t ld_volatile_global(const volatile uint64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}
// Loads one dtype_t from `ptr` with __builtin_nontemporal_load().
// The load is performed on the same-sized type VecInt<sizeof(dtype_t)>::vec_t and the
// bits are then reinterpreted as dtype_t, so sizeof(dtype_t) must be a size for which
// VecInt provides a `vec_t`.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
    using raw_t = typename VecInt<sizeof(dtype_t)>::vec_t;
    const raw_t *raw_ptr = reinterpret_cast<const raw_t *>(ptr);
    auto raw_value = __builtin_nontemporal_load(raw_ptr);
    return *reinterpret_cast<dtype_t *>(&raw_value);
}
////////////////// used in ibgda
// st_na_relaxed scalar overloads: each casts away the constness of `ptr` and performs an
// atomic store of `val` with __ATOMIC_RELAXED ordering at __HIP_MEMORY_SCOPE_AGENT.

// uint8_t overload.
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
    __hip_atomic_store(const_cast<uint8_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// uint16_t overload.
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
    __hip_atomic_store(const_cast<uint16_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// uint32_t overload.
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
    __hip_atomic_store(const_cast<uint32_t *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// int overload.
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
    __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELAXED,
                       __HIP_MEMORY_SCOPE_AGENT);
}
// int4 overload of st_na_relaxed: writes the four components of `val` to `*ptr` one by
// one (x, y, z, w) with plain assignments after casting away constness. Unlike the scalar
// overloads, no __hip_atomic_store is used here.
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
    int4 *dst = const_cast<int4 *>(ptr);
    dst->x = val.x;
    dst->y = val.y;
    dst->z = val.z;
    dst->w = val.w;
}
// int overload of st_na_release: casts away the constness of `ptr` and performs an atomic
// store of `val` with __ATOMIC_RELEASE ordering at __HIP_MEMORY_SCOPE_AGENT.
//
// This overload previously passed __ATOMIC_RELAXED, which made it behave like
// st_na_relaxed and disagree with the uint32_t / uint64_t overloads of st_na_release.
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
    int *non_const_ptr = const_cast<int *>(ptr);
    __hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// uint32_t overload of st_na_release: atomic store of `val` with __ATOMIC_RELEASE ordering
// at __HIP_MEMORY_SCOPE_AGENT, after casting away the constness of `ptr`.
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
    __hip_atomic_store(const_cast<uint32_t *>(ptr), val, __ATOMIC_RELEASE,
                       __HIP_MEMORY_SCOPE_AGENT);
}

// uint64_t overload of st_na_release: same ordering and scope as the uint32_t overload.
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
    __hip_atomic_store(const_cast<uint64_t *>(ptr), val, __ATOMIC_RELEASE,
                       __HIP_MEMORY_SCOPE_AGENT);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
//
// Generic store helper: reinterprets `ptr` and `value` as the same-sized integer type
// VecInt<sizeof(dtype_t)>::vec_t and forwards to the st_na_global overload/specialization
// for that type.
// NOTE(review): explicit specializations exist in this file for int, int64_t, float and
// int4 only. For sizes whose `vec_t` has no specialization (int8_t, int16_t, and the
// 16-byte native_int4, which does not appear to be the same type as int4), this call
// looks like it would re-enter this primary template -- confirm before using such types.
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t *>(ptr),
*reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t *>(&value));
}
// Explicit specializations of st_na_global: each casts away the constness of `ptr` and
// writes `value` with a plain assignment.

// int specialization.
template <>
__device__ __forceinline__ void st_na_global(const int *ptr, const int &value) {
    *const_cast<int *>(ptr) = value;
}

// int64_t specialization.
template <>
__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t &value) {
    *const_cast<int64_t *>(ptr) = value;
}

// float specialization.
template <>
__device__ __forceinline__ void st_na_global(const float *ptr, const float &value) {
    *const_cast<float *>(ptr) = value;
}

// int4 specialization.
template <>
__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4 &value) {
    *const_cast<int4 *>(ptr) = value;
}
// Computes the half-open token range [token_start_idx, token_end_idx) assigned to `sm_id`
// when `num_tokens` tokens are divided over `num_sms` in chunks of
// DIVUP(num_tokens, num_sms). Both bounds are clamped to `num_tokens`, so trailing ids can
// receive an empty range.
//
// Outputs are written through the two reference parameters.
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
                                                       int &token_start_idx, int &token_end_idx) {
    const int chunk = DIVUP(num_tokens, num_sms);
    token_start_idx = min(chunk * sm_id, num_tokens);
    token_end_idx = min(token_start_idx + chunk, num_tokens);
}
// Returns the copy of `ptr` held by lane `src_lane_idx`, obtained by shuffling the object
// one int-sized word at a time through shfl_sync().
// sizeof(dtype_t) must be a multiple of sizeof(int) (enforced by EP_STATIC_ASSERT).
template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
    EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
    constexpr int kNumWords = sizeof(dtype_t) / sizeof(int);

    int *local_words = reinterpret_cast<int *>(&ptr);
    int received_words[kNumWords];
#pragma unroll
    for (int word = 0; word < kNumWords; ++word)
        received_words[word] = shfl_sync(local_words[word], src_lane_idx);

    return *reinterpret_cast<dtype_t *>(received_words);
}
// Butterfly sum over lanes using shfl_xor: XOR offsets 16, 8, 4, 2, 1 are always applied,
// preceded by offset 32 when kWarpSize == 64. Returns the accumulated value for the
// calling lane.
__forceinline__ __device__ int warp_reduce_sum(int value) {
    // Extra step only for 64-wide wavefronts.
    if constexpr (kWarpSize == 64)
        value += shfl_xor<int>(value, 32);

#pragma unroll
    for (int offset = 16; offset >= 1; offset >>= 1)
        value += shfl_xor<int>(value, offset);

    return value;
}
// Returns threadIdx.x modulo kWarpSize as an int.
__forceinline__ __device__ int get_lane_id() {
    return static_cast<int>(threadIdx.x % kWarpSize);
}
// Block-level barrier across kNumRanks peers using per-rank signal arrays.
//
// barrier_signal_ptrs: table of signal arrays indexed by rank; this rank's own array is
//                      barrier_signal_ptrs[rank].
// rank:                index of the calling rank.
// kSyncOnly:           when false (default), a system fence and a block sync are issued
//                      first so earlier writes by the block are published before signalling.
//
// Protocol as implemented below: thread i (i < kNumRanks) adds FINISHED_SUM_TAG to slot i
// of the local array and subtracts it from slot `rank` of peer i's array, then the block
// spins until every polled local slot is <= 0. If the wait exceeds NUM_TIMEOUT_CYCLES, the
// polling threads print a diagnostic and trap. Requires kNumRanks <= blockDim.x.
template <int kNumRanks, bool kSyncOnly = false>
__forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int rank) {
auto thread_id = static_cast<int>(threadIdx.x);
// For non-sync-only cases, the memory operations by other threads in the block must be visible
// to the `sys` scope
if constexpr (not kSyncOnly) {
memory_fence();
__syncthreads();
}
// Add self-ranks, sub other ranks
if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
}
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
// Check timeout
auto start_time = clock64();
while (true) {
// Threads beyond kNumRanks contribute 0, which always satisfies the `<= 0` vote below.
auto value =
thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
// Leave the loop only when the vote over kFullWarpMask reports every slot as <= 0.
if (__all_sync(kFullWarpMask, value <= 0))
break;
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) {
printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n", rank,
thread_id, value);
trap();
}
}
// Hold the whole block until the polling threads have observed completion.
__syncthreads();
}
} // namespace deep_ep
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment