Commit 5563b6d0 authored by lijian6
Browse files

Fitter for DCU.


Signed-off-by: lijian <lijian6@sugon.com>
parent da6ca24e
# Launch tests/test_internode.py as node rank 0 of a two-node torchrun job
# (one process per node). Run from the repository root: the test path is relative.

# Open MPI MCA parameters: pml and osc are set to "ucx", the hcoll collective
# component is disabled, and the rmaps mapping policy is "slot:numa".
export OMPI_MCA_pml=ucx OMPI_MCA_osc=ucx OMPI_MCA_coll_hcoll_enable=0
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"

# UCX settings: transport list, network devices (device:port pairs) and the
# ROCm IPC signal-pool size.
export UCX_TLS=rc,rocm
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384

# rocSHMEM settings.
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export ROCSHMEM_HEAP_SIZE=10737418240  # 10 * 1024^3; presumably bytes - confirm

# Devices exposed to HIP, and the DeepEP checkout prepended to the Python path.
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH

torchrun \
    --nproc-per-node=1 \
    --nnodes=2 \
    --node-rank=0 \
    --master-addr="10.16.1.37" \
    --master-port=1234 \
    tests/test_internode.py
# Alternative test entry point, kept disabled:
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
# Launch tests/test_internode.py as node rank 1 of a two-node torchrun job
# (one process per node). Run from the repository root: the test path is relative.

# Open MPI MCA parameters: pml and osc are set to "ucx", the hcoll collective
# component is disabled, and the rmaps mapping policy is "slot:numa".
export OMPI_MCA_pml=ucx OMPI_MCA_osc=ucx OMPI_MCA_coll_hcoll_enable=0
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"

# UCX settings: transport list, network devices (device:port pairs) and the
# ROCm IPC signal-pool size.
export UCX_TLS=rc,rocm
export UCX_NET_DEVICES=mlx5_2:1,mlx5_4:1,mlx5_6:1,mlx5_8:1
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384

# rocSHMEM settings.
# export ROCSHMEM_UNIQUEID_WITH_MPI=1
export ROCSHMEM_MAX_NUM_CONTEXTS=32
export ROCSHMEM_HEAP_SIZE=10737418240  # 10 * 1024^3; presumably bytes - confirm

# Devices exposed to HIP, and the DeepEP checkout prepended to the Python path.
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH=/work/Tmp/DeepEP:$PYTHONPATH

torchrun \
    --nproc-per-node=1 \
    --nnodes=2 \
    --node-rank=1 \
    --master-addr="10.16.1.37" \
    --master-port=1234 \
    tests/test_internode.py
# Alternative test entry point, kept disabled:
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/internode_lj.py
#!/bin/bash
# Build the deep_ep_cpp Python extension with hipcc for gfx936, then package a wheel.
#
# Steps:
#   1. Compile every source in SOURCES to ${BUILD_DIR}/<name>.o with -fgpu-rdc.
#   2. Link the objects into ${OUT_SO} (same flags and flag order as before).
#   3. Run `setup.py bdist_wheel` and list dist/.
#
# Run from the repository root: source and output paths are relative.
# Set OUT_SO in the environment to link to a scratch file (for example
# OUT_SO=aaa.so) instead of the packaged module path.
set -eux

BUILD_DIR=build_
TORCH_DIR=/usr/local/lib/python3.10/dist-packages/torch
OMPI_DIR=/public/home/lishen/Code/rocSHMEM/3rd_party/install_dtk25.04.1/ompi
OUT_SO="${OUT_SO:-deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so}"

# The object directory must exist before hipcc writes into it; under `set -e`
# a missing directory would abort the build at the first compile.
mkdir -p "${BUILD_DIR}"

INCLUDE_FLAGS=(
    -Icsrc/
    -I./rocshmem_dir/include/
    "-I${OMPI_DIR}/include"
    "-I${TORCH_DIR}/include"
    "-I${TORCH_DIR}/include/torch/csrc/api/include"
    "-I${TORCH_DIR}/include/TH"
    "-I${TORCH_DIR}/include/THC"
    "-I${TORCH_DIR}/include/THH"
    -I/opt/dtk/include
    -I/usr/include/python3.10/
)

COMPILE_FLAGS=(
    -fPIC
    -D__HIP_PLATFORM_AMD__=1
    -DUSE_ROCM=1
    -DHIPBLAS_V2
    -DCUDA_HAS_FP16=1
    -D__HIP_NO_HALF_OPERATORS__=1
    -D__HIP_NO_HALF_CONVERSIONS__=1
    -O3
    -fgpu-rdc
    -DTORCH_API_INCLUDE_EXTENSION_H
    '-DPYBIND11_COMPILER_TYPE="_gcc"'
    '-DPYBIND11_STDLIB="_libstdcpp"'
    '-DPYBIND11_BUILD_ABI="_cxxabi1014"'
    -DTORCH_EXTENSION_NAME=deep_ep_cpp
    -D_GLIBCXX_USE_CXX11_ABI=1
    --offload-arch=gfx936
    -std=c++17
)

# Compiled in this order; each object is named after the source's basename.
SOURCES=(
    ./csrc/kernels/intranode.hip
    ./csrc/kernels/runtime.hip
    ./csrc/kernels/layout.cu
    ./csrc/deep_ep.hip
    ./csrc/kernels/internode.hip
)

for src in "${SOURCES[@]}"; do
    obj="${BUILD_DIR}/$(basename "${src%.*}").o"
    /opt/dtk/bin/hipcc "${INCLUDE_FLAGS[@]}" -c "${src}" -o "${obj}" "${COMPILE_FLAGS[@]}"
done

# Objects are passed to the linker in the same order as the original command.
LINK_OBJECTS=(
    "${BUILD_DIR}/internode.o"
    "${BUILD_DIR}/intranode.o"
    "${BUILD_DIR}/runtime.o"
    "${BUILD_DIR}/deep_ep.o"
    "${BUILD_DIR}/layout.o"
)

# NOTE(review): the link step resolves `hipcc` from PATH while the compile step
# uses /opt/dtk/bin/hipcc - confirm both are the same toolchain.
# Flag order is left untouched because library order matters to the linker.
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g \
    -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 \
    -shared -Wl,-O1 -Wl,-Bsymbolic-functions \
    "${LINK_OBJECTS[@]}" \
    -L/work/Tmp/DeepEP/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib \
    -L/usr/lib/x86_64-linux-gnu \
    -lhipblaslt -lamdhip64 \
    -o "${OUT_SO}" \
    -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared \
    "-Wl,-soname,${OUT_SO}" \
    -Wl,-rpath,/work/Tmp/DeepEP/rocshmem_dir/lib/ \
    -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 \
    /opt/dtk/hip/lib/libgalaxyhip.so.5.2.25211.1469-8d6b0397 \
    /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a \
    /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 \
    "-L${TORCH_DIR}/lib" -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib \
    -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip \
    -lrocm-core -lrocm_smi64 -l:librocshmem.a \
    -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so \
    "-Wl,-rpath,${OMPI_DIR}/lib/" \
    -libverbs -lmlx5

# build whl
# Use the same interpreter that is reported below.
echo "Using Python: $(which python3)"
python3 --version
python3 setup.py bdist_wheel
echo "✅ Build complete:"
ls -lh dist/
#pragma once
#include "kernels/api.cuh"
#include "./kernels/api.cuh"
#include "./kernels/configs.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
// Integer division of `a` by `b` with the numerator biased by `b - 1`,
// i.e. (a + b - 1) / b. For non-negative `a` and positive `b` this rounds up.
template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
    // `auto` keeps whatever type the arithmetic promotes to, so nothing is
    // narrowed before the division.
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}
// Returns ceil_div(a, b) * b: the number of `b`-sized units reported by
// ceil_div, scaled back up by `b`.
template <typename dtype_t>
dtype_t align_up(dtype_t a, dtype_t b) {
    const dtype_t num_units = ceil_div<dtype_t>(a, b);
    return num_units * b;
}
// Returns (a / b) * b: `a` with the remainder of the integer division by `b`
// dropped.
template <typename dtype_t>
dtype_t align_down(dtype_t a, dtype_t b) {
    const auto whole_units = a / b;
    return whole_units * b;
}
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
......@@ -27,77 +13,91 @@ struct Config {
int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens;
Config(int num_sms,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
num_sms(num_sms),
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
: num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
num_max_rdma_chunked_recv_tokens > 0);
// Ceil up RDMA buffer size
this->num_max_rdma_chunked_recv_tokens = align_up<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
this->num_max_rdma_chunked_recv_tokens =
ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
// have space to push
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
num_max_rdma_chunked_recv_tokens / 2);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
#ifndef DISABLE_NVSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
#ifndef DISABLE_ROCSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
internode::get_source_meta_bytes();
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK *
sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens *
kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_NVSHMEM
#ifndef DISABLE_ROCSHMEM
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0;
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens *
kNumMaxScales * sizeof(float) * 2;
num_bytes +=
num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
"rocSHMEM by following docs/install_dependencies.md");
#endif
}
};
......@@ -105,33 +105,35 @@ struct Config {
struct LowLatencyBuffer {
int num_clean_int = 0;
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
void *dispatch_rdma_send_buffer = nullptr;
void *dispatch_rdma_recv_data_buffer = nullptr;
int *dispatch_rdma_recv_count_buffer = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
void *combine_rdma_send_buffer = nullptr;
void *combine_rdma_recv_data_buffer = nullptr;
int *combine_rdma_recv_flag_buffer = nullptr;
void* combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
void *combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
std::pair<int*, int> clean_meta() {
std::pair<int *, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
};
struct LowLatencyLayout {
size_t total_bytes = 0;
size_t total_bytes = 0;
LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
typename in_ptr_t = void *>
out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
// Dispatch and combine layout:
......@@ -140,56 +142,69 @@ struct LowLatencyLayout {
// - 2 symmetric odd/even signaling buffers
// Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
// NOTES: you should add a control `int4` for combine messages if you want to do data
// transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
size_t num_bytes_per_dispatch_msg =
sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t dispatch_send_buffer_bytes =
num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers
// TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
size_t dispatch_recv_data_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes =
num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes =
std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = align_up<size_t>(signaling_buffer_bytes, 128);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes =
std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) {
for (int i = 0; i < 2; ++i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
num_bytes_per_combine_msg
};
num_bytes_per_combine_msg};
}
}
};
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
// Buffer-sizing configuration: an SM count plus the maximum chunked send/recv
// token counts for the NVL and RDMA paths. The constructor validates the values
// and rounds the RDMA receive count up; the two `get_*_buffer_size_hint` methods
// turn the configuration into byte counts rounded up to a multiple of 128.
struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

    // Stores the five values and checks them with EP_HOST_ASSERT.
    //
    // NOTE: the constructor parameters shadow the members of the same name, so
    // every assertion below inspects the *argument* values. Only the explicit
    // `this->` assignment touches a member; the checks that follow it therefore
    // still see the un-aligned `num_max_rdma_chunked_recv_tokens` argument.
    Config(int num_sms, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens)
        : num_sms(num_sms), num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
          num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
          num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
          num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
        EP_HOST_ASSERT(num_sms >= 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and
                       num_max_nvl_chunked_recv_tokens > 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and
                       num_max_rdma_chunked_recv_tokens > 0);

        // Ceil up RDMA buffer size
        this->num_max_rdma_chunked_recv_tokens =
            ALIGN<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
        // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always
        // have space to push
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <=
                       num_max_rdma_chunked_recv_tokens / 2);
    }

    // Byte-size hint for the NVL buffer, rounded up to a multiple of 128.
    //
    // `hidden_bytes` is the per-token payload size in bytes; `num_ranks` must be
    // below NUM_MAX_NVL_PEERS or a multiple of it. When `num_ranks` exceeds
    // NUM_MAX_NVL_PEERS, `num_sms` must be even.
    //
    // All products are evaluated in size_t. Multiplying the `int` factors first
    // could overflow `int` before the result was widened.
    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
        constexpr size_t kNumMaxTopK = 128;
        constexpr size_t kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);

        const auto num_rdma_ranks =
            static_cast<size_t>(std::max(num_ranks / NUM_MAX_NVL_PEERS, 1));
        const auto num_nvl_ranks = static_cast<size_t>(std::min(num_ranks, NUM_MAX_NVL_PEERS));
        const auto num_channels = static_cast<size_t>(num_sms / 2);

        // Token slots across every (channel, NVL rank) pair.
        const size_t num_slots = num_channels * num_nvl_ranks *
                                 static_cast<size_t>(num_max_nvl_chunked_recv_tokens);

        size_t num_bytes = 0;
        num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
        num_bytes += num_slots * hidden_bytes;
#ifndef DISABLE_ROCSHMEM
        num_bytes += num_slots * static_cast<size_t>(internode::get_source_meta_bytes());
#endif
        num_bytes += num_slots * kNumMaxTopK * sizeof(int64_t);
        num_bytes += num_slots * kNumMaxTopK * sizeof(float);
        num_bytes += num_slots * kNumMaxScales * sizeof(float);
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    // Byte-size hint for the RDMA buffer, rounded up to a multiple of 128.
    //
    // Returns 0 when `num_ranks <= NUM_MAX_NVL_PEERS`. Otherwise `num_ranks` must
    // be a multiple of NUM_MAX_NVL_PEERS and `num_sms` must be even. When the
    // build defines DISABLE_ROCSHMEM the call always fails its assertion.
    //
    // All products are evaluated in size_t, for the same overflow reason as in
    // get_nvl_buffer_size_hint.
    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
#ifndef DISABLE_ROCSHMEM
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
        constexpr size_t kNumMaxTopK = 128;
        constexpr size_t kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_sms % 2 == 0);

        const auto num_rdma_ranks = static_cast<size_t>(num_ranks / NUM_MAX_NVL_PEERS);
        const auto num_channels = static_cast<size_t>(num_sms / 2);

        // Token slots across every (channel, RDMA rank) pair.
        const size_t num_slots = num_channels * num_rdma_ranks *
                                 static_cast<size_t>(num_max_rdma_chunked_recv_tokens);

        size_t num_bytes = 0;
        num_bytes += num_channels * num_rdma_ranks *
                     static_cast<size_t>(NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
        num_bytes += num_slots * static_cast<size_t>(hidden_bytes) * 2;
        num_bytes += num_slots * static_cast<size_t>(internode::get_source_meta_bytes()) * 2;
        num_bytes += num_slots * kNumMaxTopK * sizeof(int64_t) * 2;
        num_bytes += num_slots * kNumMaxTopK * sizeof(float) * 2;
        num_bytes += num_slots * kNumMaxScales * sizeof(float) * 2;
        num_bytes += num_slots * sizeof(int4) * 2;
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
#else
        (void)hidden_bytes;
        (void)num_ranks;
        EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install "
                                 "rocSHMEM by following docs/install_dependencies.md");
        // Presumably unreachable (EP_HOST_ASSERT is expected not to return on
        // failure - confirm). Present so every path of this non-void function
        // returns a value.
        return 0;
#endif
    }
};
// Plain bundle of pointers and sizes describing one low-latency buffer set.
// Member order is significant: the struct is filled by positional brace
// initialisation, so fields must stay in this sequence.
struct LowLatencyBuffer {
    // Element count returned alongside the pointer by clean_meta().
    int num_clean_int{0};

    // Dispatch-side buffers.
    void* dispatch_rdma_send_buffer{nullptr};
    void* dispatch_rdma_recv_data_buffer{nullptr};
    int* dispatch_rdma_recv_count_buffer{nullptr};

    // Combine-side buffers.
    void* combine_rdma_send_buffer{nullptr};
    void* combine_rdma_recv_data_buffer{nullptr};
    int* combine_rdma_recv_flag_buffer{nullptr};

    void* combine_rdma_send_buffer_data_start{nullptr};
    size_t num_bytes_per_combine_msg{0};

    // Returns {dispatch_rdma_recv_count_buffer, num_clean_int}. Asserts first
    // that the dispatch count buffer and the combine flag buffer are the same
    // pointer.
    std::pair<int*, int> clean_meta() {
        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
        int* const signal_base = dispatch_rdma_recv_count_buffer;
        return {signal_base, num_clean_int};
    }
};
// Partitions one RDMA buffer into two low-latency buffer sets and records the
// total number of bytes the partitioning needs.
//
// Region order inside the buffer:
//   [signaling 0][signaling 1][send 0][send 1][recv 0][recv 1]
// Dispatch and combine share the same send/recv/signaling regions.
struct LowLatencyLayout {
    // Total size of all regions, computed by the constructor.
    size_t total_bytes = 0;
    // The two buffer sets, filled by the constructor.
    LowLatencyBuffer buffers[2];

    // Casts `ptr` to `count_ptr_t`, moves it forward by `count` elements
    // (bytes with the default `uint8_t *`), and casts the result to `out_ptr_t`.
    template <typename out_ptr_t = void *, typename count_ptr_t = uint8_t *,
              typename in_ptr_t = void *>
    out_ptr_t advance(const in_ptr_t &ptr, size_t count) {
        const auto base = reinterpret_cast<count_ptr_t>(ptr);
        return reinterpret_cast<out_ptr_t>(base + count);
    }

    // Computes region sizes and assigns the pointers of both buffer sets
    // relative to `rdma_buffer`. `num_ranks` is accepted but not used here.
    LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
                     int num_ranks, int num_experts) {
        // One float scale for every 128 hidden elements.
        const int num_scales = hidden / 128;

        // Message sizes
        // NOTES: you should add a control `int4` for combine messages if you want to do data
        // transformation
        EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
        const size_t bf16_row_bytes = hidden * sizeof(hip_bfloat16);
        const size_t scaled_row_bytes = hidden + num_scales * sizeof(float);
        // A dispatch message is an `int4` header plus the larger of the two row encodings.
        const size_t dispatch_msg_bytes =
            sizeof(int4) + std::max(bf16_row_bytes, scaled_row_bytes);
        const size_t combine_msg_bytes = hidden * sizeof(hip_bfloat16);

        // Message slots needed when every expert may hold the per-rank maximum.
        const auto num_expert_slots = num_experts * num_max_dispatch_tokens_per_rank;

        // Send region: large enough for either the dispatch or the combine use.
        const size_t dispatch_send_bytes = num_max_dispatch_tokens_per_rank * dispatch_msg_bytes;
        const size_t combine_send_bytes = num_expert_slots * combine_msg_bytes;
        const size_t send_buffer_bytes = std::max(dispatch_send_bytes, combine_send_bytes);
        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);

        // Receive region (symmetric)
        // TODO: optimize memory usages
        const size_t dispatch_recv_bytes = num_expert_slots * dispatch_msg_bytes;
        const size_t combine_recv_bytes = num_expert_slots * combine_msg_bytes;
        const size_t recv_buffer_bytes = std::max(dispatch_recv_bytes, combine_recv_bytes);
        EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);

        // Signaling region (symmetric): one `int` per expert, padded to 128 bytes.
        const size_t recv_count_bytes = num_experts * sizeof(int);
        const size_t recv_flag_bytes = recv_count_bytes;
        const size_t signaling_buffer_bytes = std::max(recv_count_bytes, recv_flag_bytes);
        const size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);

        // Every region exists twice, once per buffer set.
        total_bytes += 2 * (send_buffer_bytes + recv_buffer_bytes + signaling_buffer_bytes_aligned);

        // Assign pointers
        // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
        // so you may see some parameters are duplicated
        const size_t send_region_start = signaling_buffer_bytes_aligned * 2;
        const size_t recv_region_start = send_region_start + send_buffer_bytes * 2;
        const int num_clean_int = static_cast<int>(signaling_buffer_bytes / sizeof(int));
        for (int idx = 0; idx < 2; ++idx) {
            void *send_ptr = advance(rdma_buffer, send_region_start + send_buffer_bytes * idx);
            void *recv_ptr = advance(rdma_buffer, recv_region_start + recv_buffer_bytes * idx);
            int *signal_ptr = advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * idx);
            // Positional order follows the member order of LowLatencyBuffer.
            buffers[idx] = {num_clean_int, send_ptr,   recv_ptr, signal_ptr,       send_ptr,
                            recv_ptr,      signal_ptr, send_ptr, combine_msg_bytes};
        }
    }
};
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#pragma once
// Forcibly disable NDEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "./kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config.hpp"
#include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
namespace deep_ep {
......@@ -27,27 +18,27 @@ struct Buffer {
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
int64_t num_nvl_bytes;
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr;
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void **buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr;
void *rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
int device_id;
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::cuda::CUDAStream comm_stream;
at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
......@@ -58,26 +49,29 @@ private:
bool destroyed = false;
// Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr;
int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int **barrier_signal_ptrs_gpu = nullptr;
// Workspace
void* workspace = nullptr;
void *workspace = nullptr;
// Host-side MoE info
volatile int* moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr;
volatile int *moe_recv_counter = nullptr;
int *moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int* moe_recv_expert_counter = nullptr;
int* moe_recv_expert_counter_mapped = nullptr;
volatile int *moe_recv_expert_counter = nullptr;
int *moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;
volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr;
bool use_default_stream_as_comm_stream = false;
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
~Buffer() noexcept(false);
......@@ -97,70 +91,102 @@ public:
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
torch::Stream get_comm_stream() const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
void sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray> &root_unique_id_opt);
void destroy();
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config& config,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0,
const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
const Config &config, std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0, const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
internode_combine(
const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt);
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
};
} // namespace deep_ep
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#include "config_hip.hpp"
#include "event.hpp"
namespace deep_ep {
// Declaration of the communication buffer object.
// This header only declares state and member functions; every member function
// is defined elsewhere. Comments keep the original NVLink/NVSHMEM wording while
// the types are the HIP ones (`hipIpcMemHandle_t`,
// `at::hip::HIPStreamMasqueradingAsCUDA`).
struct Buffer {
// Compile-time requirement: NUM_MAX_NVL_PEERS must be exactly 8.
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
// NOTE(review): `low_latency_buffer_idx` presumably selects which low-latency
// buffer set is used next -- confirm in the implementation.
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
// Size in bytes, one pointer slot per peer, and a second pointer
// (`buffer_ptrs_gpu`, presumably the device-side copy of the table -- confirm).
int64_t num_nvl_bytes;
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void **buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
// One IPC memory handle slot per peer.
hipIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::hip::HIPStreamMasqueradingAsCUDA comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` be called, this flag will be true
bool destroyed = false;
// Barrier signals
// Same shape as the NVLink buffer pointers: one slot per peer plus a second
// pointer-to-pointer.
int *barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int **barrier_signal_ptrs_gpu = nullptr;
// Workspace
void *workspace = nullptr;
// Host-side MoE info
// Each counter below comes as a pair: a `volatile int *` and a `*_mapped`
// pointer (presumably the host view and the device-visible mapping of the same
// memory -- confirm in the implementation).
volatile int *moe_recv_counter = nullptr;
int *moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int *moe_recv_expert_counter = nullptr;
int *moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr;
// Same name as the last constructor parameter.
// NOTE(review): presumably makes the default stream serve as `comm_stream` --
// confirm in the constructor definition.
bool use_default_stream_as_comm_stream = false;
public:
// Constructor; parameter names match the private members declared above.
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
// The destructor is declared `noexcept(false)`, i.e. it is allowed to throw.
~Buffer() noexcept(false);
// Read-only queries (all `const`); their results are defined by the
// implementation, which is not part of this header.
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
// Local handle / unique id, each returned as a Python `bytearray`.
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
// Returns a tensor for the local buffer; `use_rdma_buffer` presumably chooses
// between the RDMA and the NVLink buffer -- confirm in the implementation.
torch::Tensor get_local_buffer_tensor(const pybind11::object &dtype, int64_t offset,
bool use_rdma_buffer) const;
// Returns the communication stream as a generic `torch::Stream`.
torch::Stream get_comm_stream() const;
// Takes the device ids, the gathered handles (each optional) and an optional
// root unique id. See the `available` member for the state this is expected to
// establish.
void sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray> &root_unique_id_opt);
// See the `explicitly_destroy` and `destroyed` members.
void destroy();
// The operations below share a calling convention: `previous_event` is an
// optional event passed by non-const reference, `async` and
// `allocate_on_comm_stream` are flags, and the last tuple element is an
// optional EventHandle. The meaning of the individual tuple elements is
// defined by the implementation and is not visible in this header.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Intranode dispatch. The `cached_*` arguments are optional tensors plus a
// token count (presumably reused from an earlier call -- confirm).
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor> &cached_rank_prefix_matrix,
const std::optional<torch::Tensor> &cached_channel_prefix_matrix,
int expert_alignment, int num_worst_tokens, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Intranode combine; takes two optional bias tensors in addition to `x`.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0,
const std::optional<torch::Tensor> &bias_1, const torch::Tensor &src_idx,
const torch::Tensor &rank_prefix_matrix,
const torch::Tensor &channel_prefix_matrix, const torch::Tensor &send_head,
const Config &config, std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Internode dispatch; compared with the intranode variant it additionally
// takes RDMA-rank token counts and separate RDMA/global cached prefix tensors.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor,
std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>,
torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor &x, const std::optional<torch::Tensor> &x_scales,
const std::optional<torch::Tensor> &topk_idx,
const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &num_tokens_per_rank,
const std::optional<torch::Tensor> &num_tokens_per_rdma_rank,
const torch::Tensor &is_token_in_rank,
const std::optional<torch::Tensor> &num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor> &cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor> &cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor> &cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config &config,
std::optional<EventHandle> &previous_event, bool async,
bool allocate_on_comm_stream);
// Internode combine.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(
const torch::Tensor &x, const std::optional<torch::Tensor> &topk_weights,
const std::optional<torch::Tensor> &bias_0, const std::optional<torch::Tensor> &bias_1,
const torch::Tensor &src_meta, const torch::Tensor &is_combined_token_in_rank,
const torch::Tensor &rdma_channel_prefix_matrix, const torch::Tensor &rdma_rank_prefix_sum,
const torch::Tensor &gbl_channel_prefix_matrix, const torch::Tensor &combined_rdma_head,
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
// Low-latency API. These take the sizing parameters
// (`num_max_dispatch_tokens_per_rank`, `hidden` and/or `num_experts`) directly
// instead of a Config.
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts);
// Besides the optional EventHandle, the low-latency dispatch/combine results
// end with an optional `std::function<void()>` (presumably the receive hook
// requested via `return_recv_hook` -- confirm in the implementation).
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor,
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
// `out` is optional and defaults to `std::nullopt`.
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt);
// Const query returning a tensor for the next low-latency combine buffer.
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
};
} // namespace deep_ep
#include <ATen/cuda/CUDAContext.h>
#include <memory>
#pragma once
#include <ATen/hip/HIPContext.h>
#include "kernels/exception.cuh"
namespace deep_ep {
......@@ -10,33 +10,34 @@ struct EventHandle {
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
event->record(at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
explicit EventHandle(const at::cuda::CUDAStream& stream) {
explicit EventHandle(const at::hip::HIPStreamMasqueradingAsCUDA &stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}
EventHandle(const EventHandle& other) = default;
EventHandle(const EventHandle &other) = default;
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
at::hip::getCurrentHIPStreamMasqueradingAsCUDA().unwrap().wait(*event);
}
};
torch::Event create_event(const at::cuda::CUDAStream &s) {
inline torch::Event create_event(const at::hip::HIPStreamMasqueradingAsCUDA &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s_0,
const at::hip::HIPStreamMasqueradingAsCUDA &s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
}
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
inline void stream_wait(const at::hip::HIPStreamMasqueradingAsCUDA &s, const EventHandle &event) {
s.unwrap().wait(*event.event);
}
......
......@@ -15,7 +15,6 @@ add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(layout_cuda layout.cu)
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda PARENT_SCOPE)
#pragma once
#include <hip/hip_runtime.h>
#include <vector>
#include "configs.cuh"
......@@ -9,7 +10,7 @@ namespace deep_ep {
// Intranode runtime
namespace intranode {
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
void barrier(int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);
} // namespace intranode
......@@ -18,7 +19,8 @@ namespace internode {
std::vector<uint8_t> get_unique_id();
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks,
bool low_latency_mode);
void *alloc(size_t size, size_t alignment);
......@@ -33,49 +35,46 @@ void finalize();
// Layout kernels
namespace layout {
void get_dispatch_layout(const topk_idx_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
void get_dispatch_layout(const int64_t *topk_idx, int *num_tokens_per_rank,
int *num_tokens_per_rdma_rank, int *num_tokens_per_expert,
bool *is_token_in_rank, int num_tokens, int num_topk, int num_ranks,
int num_experts, hipStream_t stream);
} // namespace layout
// Intranode kernels
namespace intranode {
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, topk_idx_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int64_t *moe_num_recv_tokens_per_experts, int num_experts, int num_tokens,
const bool *is_token_in_rank, int *channel_prefix_matrix,
int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void **buffer_ptrs, int **barrier_signal_ptrs, int rank, hipStream_t stream,
int num_sms);
void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int, void **buffer_ptrs,
int **barrier_signal_ptrs, int rank, int num_ranks, hipStream_t stream);
void dispatch(void *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
float *recv_topk_weights, int *recv_channel_offset, int *send_head, const void *x,
const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
int num_recv_buffer_tokens);
void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
int num_recv_tokens, int num_memset_int, int **barrier_signal_ptrs,
int rank, int num_ranks, hipStream_t stream);
void combine(hipDataType type, void *recv_x, float *recv_topk_weights, const void *x,
const float *topk_weights, const void *bias_0, const void *bias_1, const int *src_idx,
const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
int rank, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
int num_recv_buffer_tokens);
} // namespace intranode
......@@ -84,89 +83,52 @@ namespace internode {
int get_source_meta_bytes();
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels,
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
int num_experts, const bool *is_token_in_rank, int num_tokens,
int num_channels, int hidden_int4, int num_scales, int num_topk,
int expert_alignment, int *rdma_channel_prefix_matrix,
int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
const bool* is_token_in_rank,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
int scale_token_stride, int scale_hidden_stride,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode);
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
bool low_latency_mode);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** barrier_signal_ptrs, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
int num_ranks, int num_channels, int num_combined_tokens,
int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);
void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const void* bias_0, const void* bias_1,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
const void *bias_0, const void *bias_1, const int *combined_rdma_head,
const int *combined_nvl_head, const void *src_meta,
const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy);
} // namespace internode_ll
} // namespace deep_ep
......@@ -5,133 +5,131 @@
namespace deep_ep {
template <typename dtype_t>
struct Buffer {
template <typename dtype_t> struct Buffer {
private:
uint8_t* ptr;
uint8_t *ptr;
public:
int total_bytes;
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
__device__ __forceinline__ Buffer(void *&gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
__device__ __forceinline__ Buffer advance_also(void *&gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer() {
return reinterpret_cast<dtype_t*>(ptr);
}
__device__ __forceinline__ dtype_t *buffer() { return reinterpret_cast<dtype_t *>(ptr); }
__device__ __forceinline__ dtype_t& operator[](int idx) {
return buffer()[idx];
}
__device__ __forceinline__ dtype_t &operator[](int idx) { return buffer()[idx]; }
};
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
template <typename dtype_t, int kNumRanks = 1> struct AsymBuffer {
private:
uint8_t* ptrs[kNumRanks];
int num_bytes;
uint8_t *ptrs[kNumRanks];
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
__device__ __forceinline__ AsymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] =
reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
__device__ __forceinline__ AsymBuffer(void **gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++i) {
ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + per_channel_bytes * sm_id +
num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
}
}
__device__ __forceinline__ void advance(int shift) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
#pragma unroll
for (int i = 0; i < kNumRanks; ++i)
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
}
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
__device__ __forceinline__ AsymBuffer advance_also(void *&gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
return *this;
}
template<int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
template <int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void **gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++i)
gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
__device__ __forceinline__ dtype_t *buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1,
"`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t *>(ptrs[0] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
__device__ __forceinline__ dtype_t *buffer_by(int rank_idx, int idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
return reinterpret_cast<dtype_t *>(ptrs[rank_idx] + num_bytes * idx);
}
};
template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
template <typename dtype_t, bool kDecoupled = true> struct SymBuffer {
private:
// NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr;
uint8_t* recv_ptr;
int num_bytes;
uint8_t *send_ptr;
uint8_t *recv_ptr;
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
__device__ __forceinline__ SymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
__device__ __forceinline__ dtype_t *send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled,
"`send_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
__device__ __forceinline__ dtype_t *recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled,
"`recv_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t *>(recv_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
__device__ __forceinline__ dtype_t *buffer(int idx = 0) {
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
}
};
......
#pragma once
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024
#define NUM_WAIT_NANOSECONDS 500
#ifndef ENABLE_FAST_DEBUG
#define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#else
#define NUM_CPU_TIMEOUT_SECS 10
#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s
#endif
#define NUM_TIMEOUT_CYCLES 200000000000ll // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500
#define NUM_WAIT_CYCLES_TIMES_64 16
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
// Make CLion CUDA indexing work
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
#endif
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_RECV_TOKENS 256
#define DEFAULT_NUM_MAX_RDMA_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_RDMA_CHUNKED_RECV_TOKENS 256
static constexpr int32_t kWarpSize = 64;
// Half of the ROCm wavefront size, i.e. equal to the NVIDIA warp size (32)
static constexpr int32_t kEmulatedWarpSize = kWarpSize / 2;
static constexpr uint64_t kFullWarpMask = 0xffffffffffffffff;
static constexpr uint64_t kFirstHalfMask = 0x00000000ffffffff;
static constexpr uint64_t kSecondHalfMask = 0xffffffff00000000;
// Division rounded up: adds `y - 1` to `x` before dividing by `y`. For positive integers this
// equals ceil(x / y).
template <typename T> constexpr inline __host__ __device__ T DIVUP(const T &x, const T &y) {
    // Keep the `x + (y - 1)` grouping so the intermediate values match the one-line form.
    const auto biased_numerator = x + (y - 1);
    return biased_numerator / y;
}
// Rounds `a` up to a multiple of `b` by computing `DIVUP(a, b) * b`.
// Declared `constexpr` to match `DIVUP`, so the result can be used in constant expressions
// (array bounds, static assertions); runtime callers are unaffected.
template <typename T> constexpr inline __host__ __device__ T ALIGN(T a, T b) {
    return DIVUP<T>(a, b) * b;
}
// Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__
......@@ -43,39 +64,7 @@
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#ifndef DISABLE_SM90_FEATURES
#include <cuda_fp8.h>
#else
// Ampere does not support FP8 features
#define __NV_E4M3 0
#define __NV_E5M2 1
typedef int __nv_fp8_interpretation_t;
typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t;
#endif
namespace deep_ep {
#ifndef TOPK_IDX_BITS
#define TOPK_IDX_BITS 64
#endif
#define INT_BITS_T2(bits) int##bits##_t
#define INT_BITS_T(bits) INT_BITS_T2(bits)
typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t
#undef INT_BITS_T
#undef INT_BITS_T2
} // namespace deep_ep
#ifndef DISABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
// Remove Torch restrictions for HIP
#ifdef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_OPERATORS__
#endif
......@@ -24,9 +24,9 @@ public:
// CUDA_CHECK(cmd): evaluates `cmd` once and, when the returned status is not the runtime's
// success code, throws EPException tagged "CUDA" with __FILE__, __LINE__ and the runtime's
// error string. Wrapped in do { ... } while (0) so it behaves as a single statement.
// NOTE(review): this span shows both the cudaError_t and the hipError_t variant of the body
// back to back (diff residue); only one of them can remain in the real header.
#ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \
do { \
cudaError_t e = (cmd); \
if (e != cudaSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
hipError_t e = (cmd); \
if (e != hipSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, hipGetErrorString(e)); \
} \
} while (0)
#endif
......@@ -45,7 +45,7 @@ do { \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
__builtin_trap(); \
} \
} while (0)
#endif
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment