Commit 5563b6d0 authored by lijian6's avatar lijian6
Browse files

Fitter for DCU.


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent da6ca24e
export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
#!/bin/bash
set -eux
# if [ ! -d "build_" ]; then
# mkdir -p build_
# fi
# PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
# PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
/opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/intranode.hip -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/runtime.hip -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/deep_ep.hip -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/internode.hip -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L/work/Tmp/DeepEP/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/work/Tmp/DeepEP/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so.5.2.25211.1469-8d6b0397 /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L/usr/local/lib/python3.10/dist-packages/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/lib/ -libverbs -lmlx5
# build whl
echo "Using Python: $(which python3)"
python3 --version
python setup.py bdist_wheel
echo "✅ Build complete:"
ls -lh dist/
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/intranode.hip -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/runtime.hip -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/deep_ep.hip -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# /opt/dtk/bin/hipcc -Icsrc/ -I./rocshmem_dir/include/ -I/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10/ -c -c ./csrc/kernels/internode.hip -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
# hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L/work/Tmp/DeepEP/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o aaa.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,aaa.so -Wl,-rpath,/work/Tmp/DeepEP/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so.5.2.25211.1469-8d6b0397 /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L/usr/local/lib/python3.10/dist-packages/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi/lib/ -libverbs -lmlx5
#pragma once #pragma once
#include "kernels/api.cuh" #include "./kernels/api.cuh"
#include "./kernels/configs.cuh"
#include "kernels/exception.cuh" #include "kernels/exception.cuh"
namespace deep_ep { namespace deep_ep {
template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
dtype_t align_up(dtype_t a, dtype_t b) {
return ceil_div<dtype_t>(a, b) * b;
}
template <typename dtype_t>
dtype_t align_down(dtype_t a, dtype_t b) {
return a / b * b;
}
struct Config { struct Config {
int num_sms; int num_sms;
int num_max_nvl_chunked_send_tokens; int num_max_nvl_chunked_send_tokens;
...@@ -27,77 +13,91 @@ struct Config { ...@@ -27,77 +13,91 @@ struct Config {
int num_max_rdma_chunked_send_tokens; int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens; int num_max_rdma_chunked_recv_tokens;
Config(int num_sms, Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) : : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_sms(num_sms), num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0); EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
num_max_rdma_chunked_recv_tokens > 0);
// Ceil up RDMA buffer size // Ceil up RDMA buffer size
this->num_max_rdma_chunked_recv_tokens = align_up<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); this->num_max_rdma_chunked_recv_tokens =
ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); // have space to push
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
num_max_rdma_chunked_recv_tokens / 2);
} }
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions // Below are some assumptions
// TODO: add assertions // TODO: add assertions
constexpr int kNumMaxTopK = 128; constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128; constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2; const int num_channels = num_sms / 2;
size_t num_bytes = 0; size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
internode::get_source_meta_bytes();
#endif #endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128; num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes; return num_bytes;
} }
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_ROCSHMEM
// Legacy mode // Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS) if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0; return 0;
// Below are some assumptions // Below are some assumptions
// TODO: add assertions // TODO: add assertions
constexpr int kNumMaxTopK = 128; constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128; constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0); EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2; const int num_channels = num_sms / 2;
size_t num_bytes = 0; size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; num_bytes +=
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxScales * sizeof(float) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128; num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes; return num_bytes;
#else #else
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
"rocSHMEM by following docs/install_dependencies.md");
#endif #endif
} }
}; };
...@@ -105,33 +105,35 @@ struct Config { ...@@ -105,33 +105,35 @@ struct Config {
struct LowLatencyBuffer { struct LowLatencyBuffer {
int num_clean_int = 0; int num_clean_int = 0;
void* dispatch_rdma_send_buffer = nullptr; void *dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr; void *dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr; int *dispatch_rdma_recv_count_buffer = nullptr;
void* combine_rdma_send_buffer = nullptr; void *combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr; void *combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr; int *combine_rdma_recv_flag_buffer = nullptr;
void* combine_rdma_send_buffer_data_start = nullptr; void *combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0; size_t num_bytes_per_combine_msg = 0;
std::pair<int*, int> clean_meta() { std::pair<int *, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int}; return {dispatch_rdma_recv_count_buffer, num_clean_int};
} }
}; };
struct LowLatencyLayout { struct LowLatencyLayout {
size_t total_bytes = 0; size_t total_bytes = 0;
LowLatencyBuffer buffers[2]; LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*> template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
out_ptr_t advance(const in_ptr_t& ptr, size_t count) { typename in_ptr_t = void *>
out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count); return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
} }
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / 128; const int num_scales = hidden / 128;
// Dispatch and combine layout: // Dispatch and combine layout:
...@@ -140,56 +142,69 @@ struct LowLatencyLayout { ...@@ -140,56 +142,69 @@ struct LowLatencyLayout {
// - 2 symmetric odd/even signaling buffers // - 2 symmetric odd/even signaling buffers
// Message sizes // Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation // NOTES: you should add a control `int4` for combine messages if you want to do data
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max // transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); size_t num_bytes_per_dispatch_msg =
size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16); sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t dispatch_send_buffer_bytes =
size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2; total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers // Symmetric receive buffers
// TODO: optimize memory usages // TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t dispatch_recv_data_buffer_bytes =
size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); size_t combine_recv_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes =
std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2; total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); size_t signaling_buffer_bytes =
size_t signaling_buffer_bytes_aligned = align_up<size_t>(signaling_buffer_bytes, 128); std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2; total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers // Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer, // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated // so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) { for (int i = 0; i < 2; ++i) {
buffers[i] = { buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)), static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i), recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i), recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
num_bytes_per_combine_msg num_bytes_per_combine_msg};
};
} }
} }
}; };
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; int num_ranks, int num_experts) {
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
} }
} // namespace deep_ep } // namespace deep_ep
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
int num_max_nvl_chunked_recv_tokens;
int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens;
Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
: num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
num_max_rdma_chunked_recv_tokens > 0);
// Ceil up RDMA buffer size
this->num_max_rdma_chunked_recv_tokens =
ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
// have space to push
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
num_max_rdma_chunked_recv_tokens / 2);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
#ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
internode::get_source_meta_bytes();
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_ROCSHMEM
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0;
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxScales * sizeof(float) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
"rocSHMEM by following docs/install_dependencies.md");
#endif
}
};
struct LowLatencyBuffer {
int num_clean_int = 0;
void *dispatch_rdma_send_buffer = nullptr;
void *dispatch_rdma_recv_data_buffer = nullptr;
int *dispatch_rdma_recv_count_buffer = nullptr;
void *combine_rdma_send_buffer = nullptr;
void *combine_rdma_recv_data_buffer = nullptr;
int *combine_rdma_recv_flag_buffer = nullptr;
void *combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
std::pair<int *, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
};
struct LowLatencyLayout {
size_t total_bytes = 0;
LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
typename in_ptr_t = void *>
out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
// - 2 symmetric odd/even signaling buffers
// Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data
// transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
size_t num_bytes_per_dispatch_msg =
sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes =
num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers
// TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes =
std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes =
std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
num_bytes_per_combine_msg};
}
}
};
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#pragma once #pragma once
// Forcibly disable NDEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/pytypes.h> #include <pybind11/pytypes.h>
#include <torch/types.h> #include <torch/types.h>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "./kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config.hpp" #include "config.hpp"
#include "event.hpp" #include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
namespace deep_ep { namespace deep_ep {
...@@ -27,27 +18,27 @@ struct Buffer { ...@@ -27,27 +18,27 @@ struct Buffer {
private: private:
// Low-latency mode buffer // Low-latency mode buffer
int low_latency_buffer_idx = 0; int low_latency_buffer_idx = 0;
bool low_latency_mode = false; bool low_latency_mode = false;
// NVLink Buffer // NVLink Buffer
int64_t num_nvl_bytes; int64_t num_nvl_bytes;
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr; void **buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer // NVSHMEM Buffer
int64_t num_rdma_bytes; int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr; void *rdma_buffer_ptr = nullptr;
// Device info and communication // Device info and communication
int device_id; int device_id;
int num_device_sms; int num_device_sms;
int rank, rdma_rank, nvl_rank; int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks; int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication // Stream for communication
at::cuda::CUDAStream comm_stream; at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true // After IPC/NVSHMEM synchronization, this flag will be true
bool available = false; bool available = false;
...@@ -58,26 +49,29 @@ private: ...@@ -58,26 +49,29 @@ private:
bool destroyed = false; bool destroyed = false;
// Barrier signals // Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr; int **barrier_signal_ptrs_gpu = nullptr;
// Workspace // Workspace
void* workspace = nullptr; void *workspace = nullptr;
// Host-side MoE info // Host-side MoE info
volatile int* moe_recv_counter = nullptr; volatile int *moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr; int *moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info // Host-side expert-level MoE info
volatile int* moe_recv_expert_counter = nullptr; volatile int *moe_recv_expert_counter = nullptr;
int* moe_recv_expert_counter_mapped = nullptr; int *moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info // Host-side RDMA-level MoE info
volatile int* moe_recv_rdma_counter = nullptr; volatile int *moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr; int *moe_recv_rdma_counter_mapped = nullptr;
bool use_default_stream_as_comm_stream = false;
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy); Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -97,70 +91,102 @@ public: ...@@ -97,70 +91,102 @@ public:
pybind11::bytearray get_local_nvshmem_unique_id() const; pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
torch::Stream get_comm_stream() const; torch::Stream get_comm_stream() const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt); void sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray> &root_unique_id_opt);
void destroy(); void destroy();
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event, std::optional<EventHandle>>
bool async, bool allocate_on_comm_stream); get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>> bool allocate_on_comm_stream);
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights, std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
int expert_alignment, int num_worst_tokens, const Config& config, std::optional<EventHandle>>
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1, const std::optional<torch::Tensor> &bias_0,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>> const Config &config, std::optional<EventHandle> &previous_event, bool async,
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales, bool allocate_on_comm_stream);
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank, std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum, const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum, const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights, internode_combine(
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1, const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream); const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_experts);
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats, std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
int num_max_dispatch_tokens_per_rank, int num_experts, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
bool use_fp8, bool round_scale, bool use_ue8m0, low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
bool async, bool return_recv_hook); const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const torch::Tensor &layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts, const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
const std::optional<torch::Tensor>& out = std::nullopt); bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt);
torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
}; };
} // namespace deep_ep } // namespace deep_ep
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config_hip.hpp"
#include "event.hpp"
namespace deep_ep {
struct Buffer {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
int64_t num_nvl_bytes;
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void **buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` be called, this flag will be true
bool destroyed = false;
// Barrier signals
int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int **barrier_signal_ptrs_gpu = nullptr;
// Workspace
void *workspace = nullptr;
// Host-side MoE info
volatile int *moe_recv_counter = nullptr;
int *moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int *moe_recv_expert_counter = nullptr;
int *moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr;
bool use_default_stream_as_comm_stream = false;
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
~Buffer() noexcept(false);
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
torch::Stream get_comm_stream() const;
void sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray> &root_unique_id_opt);
void destroy();
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0,
const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
const Config &config, std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(
const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt);
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
};
} // namespace deep_ep
#include <ATen/cuda/CUDAContext.h> #pragma once
#include <memory>
#include <ATen/hip/HIPContext.h>
#include "kernels/exception.cuh" #include "kernels/exception.cuh"
namespace deep_ep { namespace deep_ep {
...@@ -10,33 +10,34 @@ struct EventHandle { ...@@ -10,33 +10,34 @@ struct EventHandle {
EventHandle() { EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA); event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream()); event->record(at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
} }
explicit EventHandle(const at::cuda::CUDAStream& stream) { explicit EventHandle(const at::hip::HIPStreamMasqueradingAsCUDA &stream) {
event = std::make_shared<torch::Event>(torch::kCUDA); event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream); event->record(stream);
} }
EventHandle(const EventHandle& other) = default; EventHandle(const EventHandle &other) = default;
void current_stream_wait() const { void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event); at::hip::getCurrentHIPStreamMasqueradingAsCUDA().unwrap().wait(*event);
} }
}; };
torch::Event create_event(const at::cuda::CUDAStream &s) { inline torch::Event create_event(const at::hip::HIPStreamMasqueradingAsCUDA &s) {
auto event = torch::Event(torch::kCUDA); auto event = torch::Event(torch::kCUDA);
event.record(s); event.record(s);
return event; return event;
} }
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s_0,
const at::hip::HIPStreamMasqueradingAsCUDA &s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id()); EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1)); s_0.unwrap().wait(create_event(s_1));
} }
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s, const EventHandle &event) {
s.unwrap().wait(*event.event); s.unwrap().wait(*event.event);
} }
......
...@@ -15,7 +15,6 @@ add_deep_ep_library(runtime_cuda runtime.cu) ...@@ -15,7 +15,6 @@ add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(layout_cuda layout.cu) add_deep_ep_library(layout_cuda layout.cu)
add_deep_ep_library(intranode_cuda intranode.cu) add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(internode_cuda internode.cu) add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES` # Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda PARENT_SCOPE)
#pragma once #pragma once
#include <hip/hip_runtime.h>
#include <vector> #include <vector>
#include "configs.cuh" #include "configs.cuh"
...@@ -9,7 +10,7 @@ namespace deep_ep { ...@@ -9,7 +10,7 @@ namespace deep_ep {
// Intranode runtime // Intranode runtime
namespace intranode { namespace intranode {
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);
} // namespace intranode } // namespace intranode
...@@ -18,7 +19,8 @@ namespace internode { ...@@ -18,7 +19,8 @@ namespace internode {
std::vector<uint8_t> get_unique_id(); std::vector<uint8_t> get_unique_id();
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode); int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks,
bool low_latency_mode);
void *alloc(size_t size, size_t alignment); void *alloc(size_t size, size_t alignment);
...@@ -33,49 +35,46 @@ void finalize(); ...@@ -33,49 +35,46 @@ void finalize();
// Layout kernels // Layout kernels
namespace layout { namespace layout {
void get_dispatch_layout(const topk_idx_t* topk_idx, void get_dispatch_layout(const int64_t *topk_idx, int *num_tokens_per_rank,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, int *num_tokens_per_rdma_rank, int *num_tokens_per_expert,
int* num_tokens_per_expert, bool* is_token_in_rank, bool *is_token_in_rank, int num_tokens, int num_topk, int num_ranks,
int num_tokens, int num_topk, int num_ranks, int num_experts, int num_experts, hipStream_t stream);
cudaStream_t stream);
} // namespace layout } // namespace layout
// Intranode kernels // Intranode kernels
namespace intranode { namespace intranode {
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int64_t *moe_num_recv_tokens_per_experts, int num_experts, int num_tokens,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, const bool *is_token_in_rank, int *channel_prefix_matrix,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
cudaStream_t stream, int num_sms); void **buffer_ptrs, int **barrier_signal_ptrs, int rank, hipStream_t stream,
int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int, void **buffer_ptrs,
cudaStream_t stream); int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, void dispatch(void *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, float *recv_topk_weights, int *recv_channel_offset, int *send_head, const void *x,
const bool* is_token_in_rank, const int* channel_prefix_matrix, const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
int scale_token_stride, int scale_hidden_stride, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks, int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
cudaStream_t stream, int num_sms, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
int num_max_send_tokens, int num_recv_buffer_tokens); int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); int num_recv_tokens, int num_memset_int, int **barrier_signal_ptrs,
int rank, int num_ranks, hipStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights, void combine(hipDataType type, void *recv_x, float *recv_topk_weights, const void *x,
const void* x, const float* topk_weights, const float *topk_weights, const void *bias_0, const void *bias_1, const int *src_idx,
const void* bias_0, const void* bias_1, const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, int rank, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
void** buffer_ptrs, int rank, int num_ranks, int num_recv_buffer_tokens);
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
} // namespace intranode } // namespace intranode
...@@ -84,89 +83,52 @@ namespace internode { ...@@ -84,89 +83,52 @@ namespace internode {
int get_source_meta_bytes(); int get_source_meta_bytes();
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
const bool* is_token_in_rank, int num_tokens, int num_channels, int num_experts, const bool *is_token_in_rank, int num_tokens,
int hidden_int4, int num_scales, int num_topk, int expert_alignment, int num_channels, int hidden_int4, int num_scales, int num_topk,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int expert_alignment, int *rdma_channel_prefix_matrix,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int** barrier_signal_ptrs, int rank, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode); bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
int* send_rdma_head, int* send_nvl_head, const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
const bool* is_token_in_rank, const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
int scale_token_stride, int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int rank, int num_ranks, bool is_cached_dispatch, int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
cudaStream_t stream, int num_channels, bool low_latency_mode); bool low_latency_mode);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, int num_ranks, int num_channels, int num_combined_tokens,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
int64_t num_rdma_bytes, int64_t num_nvl_bytes, hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode); bool is_cached_dispatch, bool low_latency_mode);
void combine(cudaDataType_t type, void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
void* combined_x, float* combined_topk_weights, const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
const bool* is_combined_token_in_rank, const void *bias_0, const void *bias_1, const int *combined_rdma_head,
const void* x, const float* topk_weights, const int *combined_nvl_head, const void *src_meta,
const void* bias_0, const void* bias_1, const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int* combined_rdma_head, const int* combined_nvl_head, const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_tokens, int num_combined_tokens, int hidden, int num_topk, int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode);
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode } // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy);
} // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
...@@ -5,133 +5,131 @@ ...@@ -5,133 +5,131 @@
namespace deep_ep { namespace deep_ep {
template <typename dtype_t> template <typename dtype_t> struct Buffer {
struct Buffer {
private: private:
uint8_t* ptr; uint8_t *ptr;
public: public:
int total_bytes; int total_bytes;
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { __device__ __forceinline__ Buffer(void *&gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t); total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t); ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
} }
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) { __device__ __forceinline__ Buffer advance_also(void *&gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
return *this; return *this;
} }
__device__ __forceinline__ dtype_t* buffer() { __device__ __forceinline__ dtype_t *buffer() { return reinterpret_cast<dtype_t *>(ptr); }
return reinterpret_cast<dtype_t*>(ptr);
}
__device__ __forceinline__ dtype_t& operator[](int idx) { __device__ __forceinline__ dtype_t &operator[](int idx) { return buffer()[idx]; }
return buffer()[idx];
}
}; };
template <typename dtype_t, int kNumRanks = 1> template <typename dtype_t, int kNumRanks = 1> struct AsymBuffer {
struct AsymBuffer {
private: private:
uint8_t* ptrs[kNumRanks]; uint8_t *ptrs[kNumRanks];
int num_bytes; int num_bytes;
public: public:
int total_bytes; int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, __device__ __forceinline__ AsymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) { int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, ""); EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t); num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks; int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms; total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; ptrs[0] =
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
} }
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, __device__ __forceinline__ AsymBuffer(void **gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) { int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, ""); EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t); num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks; int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms; total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) { for (int i = 0; i < kNumRanks; ++i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + per_channel_bytes * sm_id +
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
} }
} }
__device__ __forceinline__ void advance(int shift) { __device__ __forceinline__ void advance(int shift) {
#pragma unroll #pragma unroll
for (int i = 0; i < kNumRanks; ++ i) for (int i = 0; i < kNumRanks; ++i)
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
} }
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) { __device__ __forceinline__ AsymBuffer advance_also(void *&gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
return *this; return *this;
} }
template<int kNumAlsoRanks> template <int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { __device__ __forceinline__ AsymBuffer advance_also(void **gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i) for (int i = 0; i < kNumAlsoRanks; ++i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
return *this; return *this;
} }
__device__ __forceinline__ dtype_t* buffer(int idx = 0) { __device__ __forceinline__ dtype_t *buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); EP_STATIC_ASSERT(kNumRanks == 1,
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx); "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t *>(ptrs[0] + num_bytes * idx);
} }
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { __device__ __forceinline__ dtype_t *buffer_by(int rank_idx, int idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx); return reinterpret_cast<dtype_t *>(ptrs[rank_idx] + num_bytes * idx);
} }
}; };
template <typename dtype_t, bool kDecoupled = true> template <typename dtype_t, bool kDecoupled = true> struct SymBuffer {
struct SymBuffer {
private: private:
// NOTES: for non-decoupled case, `recv_ptr` is not used // NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr; uint8_t *send_ptr;
uint8_t* recv_ptr; uint8_t *recv_ptr;
int num_bytes; int num_bytes;
public: public:
int total_bytes; int total_bytes;
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, __device__ __forceinline__ SymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1) { int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t); num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks; int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1); total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id; send_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); recv_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
} }
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { __device__ __forceinline__ dtype_t *send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); EP_STATIC_ASSERT(kDecoupled,
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); "`send_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
} }
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { __device__ __forceinline__ dtype_t *recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); EP_STATIC_ASSERT(kDecoupled,
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx); "`recv_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t *>(recv_ptr + num_bytes * idx);
} }
__device__ __forceinline__ dtype_t* buffer(int idx = 0) { __device__ __forceinline__ dtype_t *buffer(int idx = 0) {
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx); return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
} }
}; };
......
#pragma once #pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#define NUM_MAX_NVL_PEERS 8 #define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20 #define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128 #define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024 #define FINISHED_SUM_TAG 1024
#define NUM_WAIT_NANOSECONDS 500
#ifndef ENABLE_FAST_DEBUG
#define NUM_CPU_TIMEOUT_SECS 100 #define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s #define NUM_TIMEOUT_CYCLES 200000000000ll // 200G cycles ~= 100s
#else
#define NUM_CPU_TIMEOUT_SECS 10 #define NUM_WAIT_NANOSECONDS 500
#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
#endif #define NUM_WAIT_CYCLES_TIMES_64 16
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
// Make CLion CUDA indexing work #define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) #define DEFAULT_NUM_CU 20
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) #define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#endif #define DEFAULT_NUM_MAX_XGMI_CHUNKED_RECV_TOKENS 256
#define DEFAULT_NUM_MAX_RDMA_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_RDMA_CHUNKED_RECV_TOKENS 256
static constexpr int32_t kWarpSize = 64;
// For ROCm equals to half the wave size or Nvidia warp size
static constexpr int32_t kEmulatedWarpSize = kWarpSize / 2;
static constexpr uint64_t kFullWarpMask = 0xffffffffffffffff;
static constexpr uint64_t kFirstHalfMask = 0x00000000ffffffff;
static constexpr uint64_t kSecondHalfMask = 0xffffffff00000000;
template <typename T> constexpr inline __host__ __device__ T DIVUP(const T &x, const T &y) {
return (((x) + ((y) -1)) / (y));
}
template <typename T> inline __host__ __device__ T ALIGN(T a, T b) {
return DIVUP<T>(a, b) * b;
}
// Remove Torch restrictions // Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__ #ifdef __CUDA_NO_HALF_CONVERSIONS__
...@@ -43,39 +64,7 @@ ...@@ -43,39 +64,7 @@
#undef __CUDA_NO_BFLOAT162_OPERATORS__ #undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif #endif
#include <cstdint> // Remove Torch restrictions for HIP
#include <cuda_bf16.h> #ifdef __HIP_NO_HALF_OPERATORS__
#include <cuda_runtime.h> #undef __HIP_NO_HALF_OPERATORS__
#ifndef DISABLE_SM90_FEATURES
#include <cuda_fp8.h>
#else
// Ampere does not support FP8 features
#define __NV_E4M3 0
#define __NV_E5M2 1
typedef int __nv_fp8_interpretation_t;
typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t;
#endif
namespace deep_ep {
#ifndef TOPK_IDX_BITS
#define TOPK_IDX_BITS 64
#endif
#define INT_BITS_T2(bits) int##bits##_t
#define INT_BITS_T(bits) INT_BITS_T2(bits)
typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t
#undef INT_BITS_T
#undef INT_BITS_T2
} // namespace deep_ep
#ifndef DISABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
#endif #endif
...@@ -24,9 +24,9 @@ public: ...@@ -24,9 +24,9 @@ public:
#ifndef CUDA_CHECK #ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \ #define CUDA_CHECK(cmd) \
do { \ do { \
cudaError_t e = (cmd); \ hipError_t e = (cmd); \
if (e != cudaSuccess) { \ if (e != hipSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ throw EPException("CUDA", __FILE__, __LINE__, hipGetErrorString(e)); \
} \ } \
} while (0) } while (0)
#endif #endif
...@@ -45,7 +45,7 @@ do { \ ...@@ -45,7 +45,7 @@ do { \
do { \ do { \
if (not (cond)) { \ if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \ __builtin_trap(); \
} \ } \
} while (0) } while (0)
#endif #endif
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment