Commit da13c63a authored by lishen
Browse files

完成低延迟接口功能

parent 09cb2b03
pgrep -f /usr/bin/python | xargs kill -9
export OMPI_MCA_pml=ucx export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0 export OMPI_MCA_coll_hcoll_enable=0
...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" ...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32 export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=536870912 export ROCSHMEM_HEAP_SIZE=2880100992
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
pgrep -f /usr/bin/python | xargs kill -9
export OMPI_MCA_pml=ucx export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0 export OMPI_MCA_coll_hcoll_enable=0
...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" ...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32 export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=536870912 export ROCSHMEM_HEAP_SIZE=2880100992
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
...@@ -8,8 +8,10 @@ fi ...@@ -8,8 +8,10 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])") PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])") PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}} ROCSHMEM_INSTALL_PREFIX=${ROCSHMEM_INSTALL_PREFIX:=$(pwd)/rocshmem_dir}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17}
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I${ROCSHMEM_INSTALL_PREFIX}/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
...@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ...@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5 hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,${ROCSHMEM_INSTALL_PREFIX}/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so 
/opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
......
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts) {
const int num_scales = hidden / 128; const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
......
...@@ -42,7 +42,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ ...@@ -42,7 +42,6 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
num_rdma_ranks = ::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_rdma_ranks = ::max(1, num_ranks / NUM_MAX_NVL_PEERS),
num_nvl_ranks = ::min(num_ranks, NUM_MAX_NVL_PEERS); num_nvl_ranks = ::min(num_ranks, NUM_MAX_NVL_PEERS);
#ifdef DISABLE_ROCSHMEM #ifdef DISABLE_ROCSHMEM
EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and
"rocSHMEM is disabled during compilation, please install rocSHMEM by " "rocSHMEM is disabled during compilation, please install rocSHMEM by "
...@@ -269,8 +268,11 @@ void Buffer::sync(const std::vector<int> &device_ ...@@ -269,8 +268,11 @@ void Buffer::sync(const std::vector<int> &device_
// Allocate // Allocate
rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode) auto hip_check = hipMemset(rdma_buffer_ptr, 0, num_rdma_bytes);
CUDA_CHECK(hipMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); if(hip_check != hipSuccess) {
printf("Error in hipMemset. Perhaps the value of ROCSHMEM_HEAP_SIZE needs to be greater than num_rdma_bytes(%ld)\n", num_rdma_bytes);
CUDA_CHECK(hip_check);
}
// Allocate and clean shrink buffer // Allocate and clean shrink buffer
if (enable_shrink) { if (enable_shrink) {
...@@ -1105,7 +1107,6 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te ...@@ -1105,7 +1107,6 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
send_rdma_head, send_rdma_head,
send_nvl_head, send_nvl_head,
event}; event};
#else #else
EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by " EP_HOST_ASSERT(false and "rocSHMEM is disabled during compilation, please install rocSHMEM by "
"following docs/install_dependencies.md"); "following docs/install_dependencies.md");
...@@ -1271,7 +1272,6 @@ Buffer::internode_combine( ...@@ -1271,7 +1272,6 @@ Buffer::internode_combine(
} }
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
...@@ -1282,31 +1282,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1282,31 +1282,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr); auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes); EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);
}; };
check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int64_t)); check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));
check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int64_t)); check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));
internode_ll::clean_low_latency_buffer(clean_meta_0.first, internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
clean_meta_0.second, clean_meta_1.first, clean_meta_1.second,
clean_meta_1.first,
clean_meta_1.second,
rank,
num_ranks,
mask_buffer_ptr,
sync_buffer_ptr,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA()); at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
#endif
} }
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
std::optional<EventHandle>, std::optional<std::function<void()>>> Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats, bool use_fp8, bool async, bool return_recv_hook) {
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook) {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
// Tensor checks // Tensor checks
...@@ -1318,99 +1305,62 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i ...@@ -1318,99 +1305,62 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts % num_ranks == 0);
// Diagnosis tensors
if (cumulative_local_expert_recv_stats.has_value()) {
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
}
if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)); auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1)); auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks; int num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; // 双buffer操作 auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Buffer control
LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes);
auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Wait previous tasks to be finished // Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream // NOTES: the hook mode will always use the default stream
auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream; auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not(async and return_recv_hook)); EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook) if (not return_recv_hook)
stream_wait(launch_stream, compute_stream); stream_wait(launch_stream, compute_stream);
// Allocate packed tensors // Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)); x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16));
auto packed_recv_src_info = auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-majored scales // Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>(); auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr; float* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (use_fp8) { if (use_fp8) {
if (not use_ue8m0) { EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
} }
// Kernel launch // Kernel launch
auto next_clean_meta = next_buffer.clean_meta(); auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) { auto launcher = [=](int phases) {
internode_ll::dispatch( internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_x.data_ptr(), packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(), packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(), global_atomic_counter.data_ptr<int>(),
mask_buffer_ptr, buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer,
buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer, buffer.dispatch_rdma_send_buffer,
x.data_ptr(), x.data_ptr(), topk_idx.data_ptr<int64_t>(),
topk_idx.data_ptr<int64_t>(), next_clean_meta.first, next_clean_meta.second,
next_clean_meta.first, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
next_clean_meta.second, num_topk, num_experts, rank, num_ranks, use_fp8,
num_tokens, workspace, launch_stream, phases);
hidden,
num_max_dispatch_tokens_per_rank,
num_topk,
num_experts,
rank,
num_ranks,
use_fp8,
round_scale,
use_ue8m0,
workspace,
num_device_sms,
launch_stream,
phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
...@@ -1431,20 +1381,14 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i ...@@ -1431,20 +1381,14 @@ Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_i
// Return values // Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
} }
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
// Tensor checks // Tensor checks
...@@ -1463,29 +1407,27 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1463,29 +1407,27 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2)); auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1)); auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0)); auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Buffer control
LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes);
auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
// Wait previous tasks to be finished // Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream // NOTES: the hook mode will always use the default stream
auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream; auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not(async and return_recv_hook)); EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook) if (not return_recv_hook)
stream_wait(launch_stream, compute_stream); stream_wait(launch_stream, compute_stream);
...@@ -1504,32 +1446,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1504,32 +1446,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto next_clean_meta = next_buffer.clean_meta(); auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) { auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(), internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer, buffer.combine_rdma_send_buffer,
x.data_ptr(), x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
topk_idx.data_ptr<int64_t>(), src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(),
layout_range.data_ptr<int64_t>(),
global_atomic_counter.data_ptr<int>(), global_atomic_counter.data_ptr<int>(),
mask_buffer_ptr, next_clean_meta.first, next_clean_meta.second,
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
next_clean_meta.first, num_topk, num_experts, rank, num_ranks,
next_clean_meta.second, workspace, launch_stream,
num_combined_tokens, phases, zero_copy);
hidden,
num_max_dispatch_tokens_per_rank,
num_topk,
num_experts,
rank,
num_ranks,
use_logfmt,
workspace,
num_device_sms,
launch_stream,
phases,
zero_copy);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
...@@ -1550,49 +1476,19 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1550,49 +1476,19 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Return values // Return values
return {combined_x, event, recv_hook}; return {combined_x, event, recv_hook};
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
} }
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
#ifndef DISABLE_ROCSHMEM
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16; auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
// buffer.num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(hip_bfloat16);
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
}
void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks);
internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32);
internode_ll::query_mask_buffer(
mask_buffer_ptr, num_ranks, reinterpret_cast<int*>(mask_status.data_ptr()), at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
// Forwards to `internode_ll::clean_mask_buffer` for all `num_ranks` entries on the
// current HIP stream. Asserts that `mask_buffer_ptr` is non-null first.
// NOTE(review): presumably resets every mask entry; confirm in the helper.
void Buffer::low_latency_clean_mask_buffer() {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
} }
} // namespace deep_ep } // namespace deep_ep
...@@ -1634,10 +1530,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -1634,10 +1530,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
.def("low_latency_update_mask_buffer", &deep_ep::Buffer::low_latency_update_mask_buffer)
.def("low_latency_query_mask_buffer", &deep_ep::Buffer::low_latency_query_mask_buffer)
.def("low_latency_clean_mask_buffer", &deep_ep::Buffer::low_latency_clean_mask_buffer);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); // m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.attr("int64_t") = py::cast(c10::CppTypeToScalarType<deep_ep::int64_t>::value); // m.attr("int64_t") = py::cast(c10::CppTypeToScalarType<deep_ep::int64_t>::value);
......
...@@ -26,6 +26,9 @@ private: ...@@ -26,6 +26,9 @@ private:
// Pointer table with NUM_MAX_NVL_PEERS slots, all initialised to nullptr.
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
// NOTE(review): the `_gpu` suffix suggests a device-side copy of the table above; confirm where it is allocated.
void **buffer_ptrs_gpu = nullptr; void **buffer_ptrs_gpu = nullptr;
// Second pointer table of the same size plus its pointer-to-pointer companion, both starting out null.
void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** nvl_buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer // NVSHMEM Buffer
// Byte count and base pointer of the RDMA buffer; the pointer starts out null.
int64_t num_rdma_bytes; int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr; void *rdma_buffer_ptr = nullptr;
...@@ -171,31 +174,19 @@ public: ...@@ -171,31 +174,19 @@ public:
// Member declaration only; the definition is not part of this view.
// NOTE(review): presumably clears the low-latency buffers laid out for the given sizes; confirm in the definition.
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts); int num_experts);
// Declaration of `low_latency_dispatch`. Two signatures are interleaved on these lines:
//  - one takes the optional tensors `cumulative_local_expert_recv_stats` and
//    `dispatch_wait_recv_cost_stats` plus the flags `round_scale` and `use_ue8m0`;
//  - the other takes only `use_fp8`, `async` and `return_recv_hook` after the size arguments.
// Both return a 7-element tuple: five tensors (the second one optional), an optional
// EventHandle and an optional `std::function<void()>`.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats, bool use_fp8, bool async, bool return_recv_hook);
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
// Declaration of `low_latency_combine`. Two signatures are interleaved on these lines:
// one additionally takes the optional tensor `combine_wait_recv_cost_stats` and the flag
// `use_logfmt`; the other omits both. Each returns a tuple of a tensor, an optional
// EventHandle and an optional `std::function<void()>`, and accepts an optional `out`
// tensor that defaults to `std::nullopt`.
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor &topk_weights, const torch::Tensor &src_info, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const torch::Tensor &layout_range, int num_max_dispatch_tokens_per_rank, int num_experts,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
// Declarations only. `get_next_low_latency_combine_buffer` appears twice: once
// const-qualified (first two lines) and once without `const` (end of the last line).
// The three `low_latency_*_mask_buffer` members appear in one of the two versions only.
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
void low_latency_update_mask_buffer(int rank_to_mask, bool mask);
void low_latency_query_mask_buffer(const torch::Tensor& mask_status);
void low_latency_clean_mask_buffer(); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
}; };
} // namespace deep_ep } // namespace deep_ep
...@@ -134,43 +134,32 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -134,43 +134,32 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
// Internode low-latency kernels // Internode low-latency kernels
namespace internode_ll { namespace internode_ll {
// Host-side entry point for cleaning two int64 regions (`clean_0`, `clean_1`) on `stream`.
// One interleaved signature also takes `rank`, `num_ranks`, `mask_buffer` and `sync_buffer`;
// the other stops after the two (pointer, count) pairs.
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1, int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks, hipStream_t stream);
int* mask_buffer, int* sync_buffer, hipStream_t stream);
// Host-side entry point for the low-latency dispatch. The interleaved signatures differ in
// the type of `packed_recv_x_scales` (`void*` vs `float*`) and in whether `mask_buffer`,
// the two statistics pointers, `round_scale`, `use_ue8m0` and `num_device_sms` are present.
void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx, const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
bool use_fp8, bool round_scale, bool use_ue8m0, void* workspace, hipStream_t stream, int phases);
void* workspace, int num_device_sms, hipStream_t stream, int phases);
// Host-side entry point for the low-latency combine. The interleaved signatures differ in
// whether `mask_buffer`, `combine_wait_recv_cost_stats`, `use_logfmt` and `num_device_sms`
// are present; both end with `phases` and `zero_copy`.
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer, int64_t* combine_wait_recv_cost_stats, int64_t* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int, int num_combined_tokens, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, hipStream_t stream,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
// Mask-buffer helpers (declared in one of the two interleaved versions only); each takes
// the mask buffer pointer and the stream to run on.
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, hipStream_t stream);
void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, hipStream_t stream);
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream);
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
...@@ -5,63 +5,64 @@ ...@@ -5,63 +5,64 @@
#include "utils.cuh" #include "utils.cuh"
// #include <cooperative_groups.h> // #include <cooperative_groups.h>
#include <iostream> #include <iostream>
#include "hip/hip_runtime.h"
// low latency+RocSHMEM has issue with CTX. // low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX #define ROCM_DISABLE_CTX
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp> #include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem_COLL.hpp>
using namespace rocshmem;
namespace deep_ep { namespace deep_ep {
namespace internode_ll { namespace internode_ll {
// Two device helpers are interleaved on these lines.
// is_rank_masked (first statement on each line): returns false when `mask_buffer_ptr`
//   is null; otherwise reports whether the entry at `mask_buffer_ptr + rank`, read with
//   `ld_acquire_global`, is non-zero. When `use_warp_sync` is true the loaded value is
//   passed through `shfl_sync(..., 0)` before the comparison.
// pack2 (second statement on each line): stores `x` and `y` as the first and second
//   `dtype_a_t` element of a `dtype_b_t` value and returns it; the static assert requires
//   sizeof(dtype_b_t) == 2 * sizeof(dtype_a_t).
template <bool use_warp_sync = false> template <typename dtype_a_t, typename dtype_b_t>
__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) { __device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
if (mask_buffer_ptr == nullptr) { EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
return false; dtype_b_t packed;
} auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
if constexpr (use_warp_sync) { unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return shfl_sync(ld_acquire_global(mask_buffer_ptr + rank), 0) != 0; return packed;
} else {
return ld_acquire_global(mask_buffer_ptr + rank) != 0;
}
} }
// Counter-based barrier across thread blocks.
//   global_counter: counter whose element 0 is incremented once per block.
//   num_blocks:     value the counter must reach before any block proceeds.
// Thread 0 of each block adds 1 to the counter and then spins until a relaxed load of it
// equals `num_blocks`; the surrounding __syncthreads() calls hold the other threads of the
// block. The counter is not reset here.
// NOTE(review): callers presumably supply a zeroed counter and every block must reach
// this call, otherwise the spin never ends; confirm at the call sites.
__device__ void grid_barrier(int* global_counter, int num_blocks) { __device__ void grid_barrier(int* global_counter, int num_blocks) {
// `ret` only receives the atomicAdd result; it is never read afterwards.
volatile int ret; volatile int ret;
__syncthreads(); __syncthreads();
memory_fence_gpu(); __threadfence();
if (threadIdx.x == 0 ) { if (threadIdx.x == 0 ) {
ret = atomicAdd((int*)&global_counter[0], 1); // ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
ret = atomicAdd(&global_counter[0], 1);
} }
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
while (ld_relaxed_global(global_counter) != num_blocks); while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED,__HIP_MEMORY_SCOPE_AGENT) != num_blocks);
} }
__syncthreads(); __syncthreads();
} }
// Quotient of `a` by `b` rounded up, evaluated as (a + b - 1) / b.
// The intermediate sum keeps the type the built-in arithmetic produces; only the final
// quotient is converted back to `dtype_t`.
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}
// Splits `packed` into its two `dtype_a_t` elements: the first is written to `x`,
// then the second to `y`. `dtype_b_t` must be exactly twice the size of `dtype_a_t`.
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    const dtype_a_t* halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
// Kernel that zeroes two int64 regions between two rocSHMEM barriers.
//   clean_0 / num_clean_int_0: first region and its element count; each thread writes
//     indices thread_id, thread_id + kNumThreads, ...
//   clean_1 / num_clean_int_1: second region; its loop header lies in lines this view omits.
// Thread 0 calls rocshmem::rocshmem_barrier_all() before and after the cleaning. In the
// variant that takes `sync_buffer_ptr`, a non-null pointer ends in EP_DEVICE_ASSERT(0).
template <int kNumThreads> __launch_bounds__(kNumThreads, 1) template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1, int64_t* clean_1, int num_clean_int_1) {
int rank, int num_ranks,
int* mask_buffer_ptr, int* sync_buffer_ptr) {
auto thread_id = static_cast<int>(threadIdx.x);
// Barrier before cleaning (in case of unfinished chunked EP) // Barrier before cleaning (in case of unfinished chunked EP)
if (sync_buffer_ptr == nullptr) { if (threadIdx.x == 0)
// rocshmem::rocshmem_barrier_all_wg();
if (thread_id == 0)
rocshmem::rocshmem_barrier_all(); rocshmem::rocshmem_barrier_all();
} else {
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT(0);
}
// Clean // Clean
auto thread_id = static_cast<int>(threadIdx.x);
#pragma unroll #pragma unroll
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads) for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
clean_0[i] = 0; clean_0[i] = 0;
...@@ -70,59 +71,33 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -70,59 +71,33 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_1[i] = 0; clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work // Barrier after cleaning (make sure low-latency mode work
if (sync_buffer_ptr == nullptr) { if (threadIdx.x == 0)
// rocshmem::rocshmem_barrier_all_wg();
if (thread_id == 0)
rocshmem::rocshmem_barrier_all(); rocshmem::rocshmem_barrier_all();
} else {
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT(0);
}
} }
// Host launcher for the `clean_low_latency_buffer<kNumThreads>` kernel with kNumThreads = 256.
// Builds the launch configuration with SETUP_LAUNCH_CONFIG(1, kNumThreads, stream) and
// forwards the two (pointer, count) pairs; one interleaved version also forwards `rank`,
// `num_ranks`, `mask_buffer_ptr` and `sync_buffer_ptr` and launches through LAUNCH_KERNEL,
// the other launches through LAUNCH_KERNEL_NON_COOPERATIVE.
// NOTE(review): the leading `1` is presumably the block count; confirm against the macro.
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1, int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
int* mask_buffer_ptr, int* sync_buffer_ptr,
hipStream_t stream) { hipStream_t stream) {
constexpr int kNumThreads = 256; constexpr int kNumThreads = 256;
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>, LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kNumThreads>,
clean_0, num_clean_int_0, clean_1, num_clean_int_1, clean_0, num_clean_int_0, clean_1, num_clean_int_1);
rank, num_ranks,
mask_buffer_ptr, sync_buffer_ptr);
} }
template <bool kUseFP8, bool kUseUE8M0, int kHidden> template <bool kUseFP8, int kHidden>
__launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, __global__ __launch_bounds__(16 * kWarpSize, 1) void
void* packed_recv_x_scales, dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer_ptr, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
int* cumulative_local_expert_recv_stats, const void* x, const int64_t* topk_idx,
int64_t* dispatch_wait_recv_cost_stats, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
void* rdma_recv_x, int64_t* next_clean, int num_next_clean_int,
int64_t* rdma_recv_count, int num_tokens, int num_max_dispatch_tokens_per_rank,
void* rdma_x, int num_topk, int num_experts, int rank, int num_ranks,
const void* x, int num_warp_groups, int num_warps_per_group, int phases) {
const int64_t* topk_idx,
int* atomic_counter_per_expert,
int* atomic_finish_counter_per_expert,
int64_t* next_clean,
int num_next_clean_int,
int num_tokens,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
int num_warp_groups,
int num_warps_per_group,
bool round_scale,
int phases) {
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); rocshmem::rocshmem_wg_ctx_create(0, &ctx);
...@@ -136,33 +111,21 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -136,33 +111,21 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
const auto num_local_experts = num_experts / num_ranks; const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group; const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group; const auto sub_warp_id = warp_id % num_warps_per_group;
// Each warp group is responsible for one expert
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 stuff // FP8 stuff
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL; constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = kHidden / kNumPerChannels; const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16)); const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4); const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales // Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use // NOTES: currently we have 3 reserved int fields for future use
using vec_t = std::conditional_t<kUseFP8, int2, int4>; using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16))); const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// Expert counts
constexpr int kNumMaxWarpGroups = 16; // 每个kernel最多warp group数量,即每个block负责的专家数
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
#ifdef USE_ROCM
// Used for synchronization
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize; constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ; constexpr int num_sync_large_iteration = kMaxNumWarps ;
...@@ -173,57 +136,57 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -173,57 +136,57 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
sync_large_warp_counters[i] = 0; sync_large_warp_counters[i] = 0;
} }
__syncthreads(); __syncthreads();
#endif
// Sending phase: jump straight to the receive phase when no send was requested // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV; goto LOW_LATENCY_DISPATCH_RECV;
// There are 2 kinds of warps in this part: // There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens // 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information // 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps - 1) { if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16); // 128/16 = 8 constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerRead) == 0, "Invalid hidden"); EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize; const auto num_threads = (num_warps - 1) * kWarpSize;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4; const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg); const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4)); const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes); const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index writes // Overlap top-k index read and source token index write
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast // FP8 cast
EP_STATIC_ASSERT(hidden_bf16_int4 % kWarpSize == 0, "Must use the full warp to reduce");
#pragma unroll #pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read // Read
auto int4_value = __ldg(x_int4 + i); auto int4_value = __ldg(x_int4 + i);
if constexpr (kUseFP8) { if (kUseFP8) {
// Calculate local amax // Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value); auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead]; float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv; float amax = kFP8Margin, scale, scale_inv;
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++j) { for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]); fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j])); amax = fmaxf(amax, fabsf(fp32_values[j]));
} }
// Reduce amax and scale // Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax); amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale); calculate_fp8_scales</*round_scale*/false>(amax, scale, scale_inv);
if (lane_id % 16 == 0) if (lane_id % 16 == 0)
rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv; rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
// Cast into send buffer // Cast into send buffer
vec_t int2_value; vec_t int2_value;
...@@ -240,44 +203,38 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -240,44 +203,38 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
} }
} }
__syncthreads(); __syncthreads();
// Issue IBGDA sends // Issue IBGDA sends
if (dst_expert_idx >= 0) { if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = shfl_sync(slot_idx, 0); slot_idx = shfl_sync(slot_idx, 0);
const int dst_rank = dst_expert_idx / num_local_experts; const auto dst_rank = dst_expert_idx / num_local_experts;
const int dst_expert_local_idx = dst_expert_idx % num_local_experts; const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx); const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) { if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_schar_put_nbi_wave(reinterpret_cast<signed char*>(dst_ptr),
rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx, reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
rocshmem::rocshmem_ctx_quiet(ctx);
#else
rocshmem::rocshmem_schar_put_nbi_wave(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
rocshmem::rocshmem_fence(); rocshmem::rocshmem_fence();
#endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr); const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} }
// Increase counter after finishing // Increase counter after finishing
syncwarp(); syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
} }
} }
} else if (warp_id == num_warps - 1) { }
if (warp_id == num_warps - 1) {
EP_DEVICE_ASSERT(num_sms > 1); EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) { if (sm_id == 0) {
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer // The first SM is also responsible for cleaning the next buffer
#pragma unroll #pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize) for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
...@@ -289,7 +246,6 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -289,7 +246,6 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
for (int i = lane_id; i < num_experts; i += kWarpSize) for (int i = lane_id; i < num_experts; i += kWarpSize)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
} }
// This SM should be responsible for some destination experts, read `topk_idx` for them // This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumMaxWarpGroups] = {0}; int expert_count[kNumMaxWarpGroups] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups; const auto expert_begin_idx = sm_id * num_warp_groups;
...@@ -300,12 +256,12 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -300,12 +256,12 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) { for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
auto idx = static_cast<int>(__ldg(topk_idx + i)); auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx) if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx]++; expert_count[idx - expert_begin_idx] ++;
} }
// Warp reduce // Warp reduce
#pragma unroll #pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++i) { for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
if (lane_id == 0) { if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
...@@ -314,6 +270,7 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -314,6 +270,7 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
} }
} }
//revert sync_large_warp_counters to 0 for next sync
__syncthreads(); __syncthreads();
// Issue count sends // Issue count sends
...@@ -324,17 +281,10 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -324,17 +281,10 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
// Wait local sends issued and send expert counts // Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
auto dst_ptr = reinterpret_cast<int64_t*>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
if (dst_rank != rank) { if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX) rocshmem::rocshmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
rocshmem::rocshmem_ctx_long_atomic_add(ctx, dst_ptr, -num_tokens_sent - 1, dst_rank);
#else
rocshmem::rocshmem_long_atomic_add(dst_ptr, -num_tokens_sent - 1, dst_rank);
#endif
} else { } else {
st_release_sys_global(dst_ptr, -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
} }
// Clean workspace for next use // Clean workspace for next use
...@@ -347,10 +297,8 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -347,10 +297,8 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
} }
syncwarp(); syncwarp();
// Receiving phase
// Receiving phase LOW_LATENCY_DISPATCH_RECV:
LOW_LATENCY_DISPATCH_RECV:
// Return immediately when the receive phase was not requested
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return; return;
...@@ -363,85 +311,40 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -363,85 +311,40 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
if (responsible_expert_idx < num_experts) { if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts; const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts; const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) + const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
static_cast<int4*>(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
// Wait tokens to arrive // Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0 // NOTES: using sub-warp 1 to overlap with sub-warp 0
int64_t num_recv_tokens; int num_recv_tokens, recv_token_begin_idx;
int recv_token_begin_idx; EP_DEVICE_ASSERT(num_warps_per_group > 1);
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
if (sub_warp_id == 1 and lane_id == 0) { if (sub_warp_id == 1 and lane_id == 0) {
auto start_time = wall_clock64(); while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
int64_t wait_recv_cost = 0;
int offset = local_expert_idx * num_ranks + src_rank;
if (not is_rank_masked(mask_buffer_ptr, src_rank)) {
while ((wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES) { // not timeout
if((num_recv_tokens = ld_acquire_global(reinterpret_cast<int64_t*>(
rdma_recv_count + local_expert_idx * num_ranks + src_rank))) != 0) {
break;
}
}
}
// Mask rank if timeout
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank,
local_expert_idx,
src_rank);
if (mask_buffer_ptr == nullptr)
trap();
atomicExch(mask_buffer_ptr + src_rank, 1);
}
// Do not receive tokens if rank timeout or masked
if (num_recv_tokens == 0)
num_recv_tokens = -1;
#if 1
num_recv_tokens = -num_recv_tokens - 1; num_recv_tokens = -num_recv_tokens - 1;
int num_recv_tokens_int32 = static_cast<int>(num_recv_tokens); recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens_int32);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens_int32;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens_int32, recv_token_begin_idx); recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
// Add stats for diagnosis
if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens_int32);
if (dispatch_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<uint64_t*>(dispatch_wait_recv_cost_stats + src_rank), static_cast<uint64_t>(wait_recv_cost));
}
#endif
} }
#if 1
#ifdef USE_ROCM
// no needs to reset because there is no iteration // no needs to reset because there is no iteration
if (lane_id == 0){ if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1); volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group) {}
#else
// asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
#endif
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
num_recv_tokens = shared_num_recv_tokens[warp_group_id]; num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
...@@ -458,506 +361,308 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x, ...@@ -458,506 +361,308 @@ __launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4)); const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales // Copy scales
if constexpr (kUseFP8) { if (kUseFP8) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes); const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t)); const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
const auto token_idx = recv_token_begin_idx + i; const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
const auto token_stride = num_elems_per_pack; auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; auto scale_1 = (lane_id + kWarpSize) < num_scales ? ld_nc_global(src_scales + lane_id + kWarpSize) : 0;
if (lane_id < num_scales) { lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
const auto pack_idx = lane_id / num_elems_per_pack; (lane_id + kWarpSize) < num_scales ? dst_scales[(lane_id + kWarpSize) * scale_stride] = scale_1 : 0.0f;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + kWarpSize < num_scales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
} }
} }
#endif
} }
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx); rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif #endif
} }
void dispatch(void* packed_recv_x, void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_src_info,
int64_t* packed_recv_layout_range,
int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer_ptr, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
int* cumulative_local_expert_recv_stats, const void* x, const int64_t* topk_idx,
int64_t* dispatch_wait_recv_cost_stats, int64_t* next_clean, int num_next_clean_int,
void* rdma_recv_x, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int64_t* rdma_recv_count, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* rdma_x, void* workspace, hipStream_t stream, int phases) {
const void* x,
const int64_t* topk_idx,
int64_t* next_clean,
int num_next_clean_int,
int num_tokens,
int hidden,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
bool use_fp8,
bool round_scale,
bool use_ue8m0,
void* workspace,
int num_device_sms,
hipStream_t stream,
int phases) {
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
const int num_warp_groups = DIVUP(num_experts, num_device_sms); const int num_warp_groups = ceil_div(num_experts, /*num_device_sms*/80);
EP_HOST_ASSERT(num_warp_groups <= 16); const int num_warps_per_group = 16 / num_warp_groups;
const int num_warps_per_group = 16 / num_warp_groups; // 每个kernel最大16个warp
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
const auto num_warps = num_warp_groups * num_warps_per_group; const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = DIVUP(num_experts, num_warp_groups); const auto num_sms = ceil_div(num_experts, num_warp_groups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK); EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks // Workspace checks
auto atomic_counter_per_expert = static_cast<int*>(workspace); auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) \ #define DISPATCH_LAUNCH_CASE(hidden) { \
{ \ auto dispatch_func = use_fp8 ? dispatch<true, hidden> : \
auto dispatch_func = dispatch<false, false, hidden>; \ dispatch<false, hidden>; \
if(use_fp8 and not use_ue8m0) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
dispatch_func = dispatch<true, false, hidden>; \ packed_recv_x, packed_recv_x_scales, \
if(use_fp8 and use_ue8m0) \ packed_recv_src_info, packed_recv_layout_range, \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, \
dispatch_func, \
packed_recv_x, \
packed_recv_x_scales, \
packed_recv_src_info, \
packed_recv_layout_range, \
packed_recv_count, \ packed_recv_count, \
global_atomic_counter, \ global_atomic_counter, \
mask_buffer_ptr, \ rdma_recv_x, rdma_recv_count, rdma_x, \
cumulative_local_expert_recv_stats, \ x, topk_idx, \
dispatch_wait_recv_cost_stats, \ atomic_counter_per_expert, atomic_finish_counter_per_expert, \
rdma_recv_x, \ next_clean, num_next_clean_int, \
rdma_recv_count, \ num_tokens, num_max_dispatch_tokens_per_rank, \
rdma_x, \ num_topk, num_experts, rank, num_ranks, \
x, \ num_warp_groups, num_warps_per_group, phases); } break
topk_idx, \
atomic_counter_per_expert, \
atomic_finish_counter_per_expert, \
next_clean, \
num_next_clean_int, \
num_tokens, \
num_max_dispatch_tokens_per_rank, \
num_topk, \
num_experts, \
rank, \
num_ranks, \
num_warp_groups, \
num_warps_per_group, \
round_scale, \
phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream); SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE); SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls> template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__launch_bounds__(1024, 1) __global__ void combine(void* combined_x, __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * kWarpSize, 1) void
void* rdma_recv_x, combine(void* combined_x,
int* rdma_recv_flag, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const int* src_info, const int64_t* layout_range,
const int64_t* topk_idx,
const float* topk_weights,
const int* src_info,
const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer_ptr, int64_t* next_clean, int num_next_clean_int,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean,
int num_next_clean_int,
int* atomic_clean_flag, int* atomic_clean_flag,
int num_combined_tokens, int num_combined_tokens, int hidden, int num_topk,
int hidden,
int num_topk,
int num_max_dispatch_tokens_per_rank, int num_max_dispatch_tokens_per_rank,
int num_experts, int num_experts, int rank, int num_ranks,
int rank, int phases, bool zero_copy) {
int num_ranks,
int num_warp_groups,
int num_warps_per_group,
int phases,
bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx; __shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx); rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif #endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
__syncthreads();
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
__shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
if (threadIdx.x==0){
#pragma unroll
for (int i = 0; i < kMaxNumWarps; ++i) {
sync_large_warp_counters[i] = 0;
}
}
__syncthreads();
// const auto sm_id = static_cast<int>(blockIdx.x); // Sending phase
// const auto num_sms = static_cast<int>(gridDim.x); if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// const auto thread_id = static_cast<int>(threadIdx.x); goto LOW_LATENCY_COMBINE_RECV;
// const auto num_threads = static_cast<int>(blockDim.x);
// const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / kNumWarpsPerGroup;
// const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
// const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(gpu_bfloat16_t);
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Message package
// // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
// constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// __syncthreads();
// #ifdef USE_ROCM
// // 16 is the max possible number of warps in AMD GPUs
// constexpr int kMaxNumWarps = 1024 / kWarpSize;
// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
// if (threadIdx.x==0){
// // printf("combine");
// #pragma unroll
// for (int i = 0; i < kMaxNumWarps; ++i) {
// sync_large_warp_counters[i] = 0;
// }
// }
// __syncthreads();
// #endif
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = reinterpret_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// auto src_idx = __ldg(local_src_info + token_idx);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
// if (dst_rank == rank) {
// const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
// } else {
// const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
// if (not zero_copy)
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmemx_int8_put_nbi_warp(
// #else
// internode::shmem_ctx_schar_put_nbi_warp(ctx,
// #endif
// reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_fence();
// #else
// internode::shmem_ctx_quiet(ctx);
// #endif
// }
// }
// // Put finishing flag
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
// #ifdef USE_ROCM
// if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(
// &sync_large_warp_counters[warp_group_id], 1,
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// }
// syncwarp();
// while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
// #else
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// #endif
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// if (dst_rank != rank) {
// #ifdef USE_ROCM
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #else
// internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #endif
// #else
// nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
// #endif
// } else {
// st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive and notify PCIe usage
// if (responsible_expert_idx < num_experts) {
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
// if (sub_warp_id == 0 and lane_id == 0){
// while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
// }
// }
// grid_barrier(global_atomic_counter, num_sms);
// // Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
// if (thread_id < hidden_bf16_int4) {
// for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// // Read top-k indices and weights
// int reg_topk_idx[kNumMaxTopk];
// float reg_topk_weights[kNumMaxTopk];
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) {
// reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
// reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
// }
// float combined_values[kNumElemsPerInt4] = {0.0f};
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// // Read from sources
// auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
// auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// // Reduce
// auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
// const auto x_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&x_vec);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
// }
// // Write results
// int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
// auto combined_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&combined_values);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_bf16[j] = static_cast<gpu_bfloat16_t>(combined_values[j]);
// (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
// }
// }
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
}
void combine(void* combined_x, // Clean up next buffer
void* rdma_recv_x, if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
int64_t* rdma_recv_flag, #pragma unroll
void* rdma_send_x, for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
const void* x, next_clean[i] = 0;
const int64_t* topk_idx,
const float* topk_weights,
const int* src_info,
const int64_t* layout_range,
int* global_atomic_counter,
int* mask_buffer_ptr,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean,
int num_next_clean_int,
int num_combined_tokens,
int hidden,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
bool use_logfmt,
void* workspace,
int num_device_sms,
hipStream_t stream,
int phases,
bool zero_copy) {
constexpr int kNumMaxTopk = 11;
const int num_warp_groups = DIVUP(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups;
const int num_recv_per_sm = DIVUP(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
const auto num_warps = num_warp_groups * num_warps_per_group; // Notify before executing `int_p`
const auto num_sms = max(DIVUP(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : DIVUP(num_combined_tokens, num_recv_per_sm)); syncwarp();
if (lane_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
// Check workspace // Issue IBGDA sends
auto atomic_clean_flag = static_cast<int*>(workspace); if (responsible_expert_idx < num_experts) {
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); const auto dst_rank = responsible_expert_idx / num_local_experts;
EP_HOST_ASSERT(num_topk <= kNumMaxTopk); const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout
int offset, num_tokens_to_send;
unpack2(layout, num_tokens_to_send, offset);
// Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// Online cast cannot use zero-copy //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
EP_HOST_ASSERT(not(zero_copy and use_logfmt)); #if defined(ROCM_DISABLE_CTX)
EP_HOST_ASSERT(use_logfmt == 0); rocshmem::rocshmem_schar_put_nbi_wave(
#else
constexpr int kNumMaxUnrolls = 4; rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx,
#endif
#ifdef USEING_TMA reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
constexpr int kNumStages = 3;
constexpr int kMaxNumGroups = 2;
// Send buffer size
const int num_meta_bytes = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL * 4;
const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
// Receive buffer size
const int num_recv_tma_bytes = 16 + hidden * 2;
const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// Total requirement
const int smem_size = max(smem_send_size, smem_recv_size);
#endif
// #define COMBINE_LAUNCH_CASE(hidden) \
// { \
// auto combine_func = combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// LAUNCH_KERNEL(&cfg, \
// combine_func, \
// combined_x, \
// rdma_recv_x, \
// rdma_recv_flag, \
// rdma_send_x, \
// x, \
// topk_idx, \
// topk_weights, \
// src_info, \
// layout_range, \
// global_atomic_counter, \
// mask_buffer_ptr, \
// combine_wait_recv_cost_stats, \
// next_clean, \
// num_next_clean_int, \
// atomic_clean_flag, \
// num_combined_tokens, \
// hidden, \
// num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, \
// rank, \
// num_ranks, \
// num_warp_groups, \
// num_warps_per_group, \
// phases, \
// zero_copy); \
// } \
// break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps* kWarpSize, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// #undef COMBINE_LAUNCH_CASE
}
template <int kNumThreads> #if defined(ROCM_DISABLE_CTX)
__launch_bounds__(kNumThreads, 1) __global__ void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor) { rocshmem::rocshmem_fence();
const auto num_sms = static_cast<int>(gridDim.x); #else
const auto sm_id = static_cast<int>(blockIdx.x); rocshmem::rocshmem_ctx_quiet(ctx);
const auto num_threads = num_sms * kNumThreads; #endif
const auto thread_id = sm_id * kNumThreads + static_cast<int>(threadIdx.x); }
for (int rank_id = thread_id; rank_id < num_ranks; rank_id += num_threads) {
mask_tensor[rank_id] = mask_buffer_ptr[rank_id];
} }
}
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor, hipStream_t stream) { // Put finishing flag
constexpr int num_sms = 1; EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
constexpr int kNumThreads = 1024; if (lane_id == 0){
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, query_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks, mask_tensor); volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
} }
syncwarp();
while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
rocshmem::rocshmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
#endif
} else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
syncwarp();
}
template <int kNumThreads> // Receiving phase
__launch_bounds__(kNumThreads, 1) __global__ void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask) { LOW_LATENCY_COMBINE_RECV:
const auto sm_id = static_cast<int>(blockIdx.x); if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
const auto thread_id = static_cast<int>(threadIdx.x); return;
if (sm_id == 0 && thread_id == 0) {
atomicExch(mask_buffer_ptr + rank_to_mask, mask ? 1 : 0); // Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0){
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
} }
} }
grid_barrier(global_atomic_counter, num_sms);
void update_mask_buffer(int* mask_buffer_ptr, int rank, bool mask, hipStream_t stream) { // Reduce tokens with FP8 cast
constexpr int num_sms = 1; EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
constexpr int kNumThreads = 64; EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); if (thread_id < hidden_bf16_int4) {
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, update_mask_buffer<kNumThreads>, mask_buffer_ptr, rank, mask); for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
} // Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
template <int kNumThreads> float combined_values[kNumElemsPerInt4] = {0.0f};
__launch_bounds__(kNumThreads, 1) __global__ void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks) { #pragma unroll
auto thread_id = static_cast<int>(threadIdx.x); for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
#pragma unroll #pragma unroll
for (int i = thread_id; i < num_ranks; i += kNumThreads) for (int j = 0; j < kNumElemsPerInt4; ++ j)
mask_buffer_ptr[i] = 0; combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
} }
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream) { void combine(void* combined_x,
constexpr int num_sms = 1; void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
constexpr int kNumThreads = 64; const void* x, const int64_t* topk_idx, const float* topk_weights,
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); const int* src_info, const int64_t* layout_range,
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks); int* global_atomic_counter,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kNumWarpsPerGroup = 4;
constexpr int kNumWarpGroups = 4;
constexpr int kNumMaxTopk = 9;
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = ceil_div(num_experts, kNumWarpGroups);
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
} }
} // namespace internode_ll } // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
#endif
...@@ -31,6 +31,21 @@ ...@@ -31,6 +31,21 @@
} \ } \
} }
// Cooperative element-wise copy of N elements from SRC to DST, with the loads of each
// iteration batched ahead of the stores.
//
//   UNROLL_FACTOR  compile-time constant: number of loads issued before the matching stores
//                  (it sizes a local array and a constexpr stride).
//   LANE_ID        starting element index of the calling lane
//                  (presumably the lane index in [0, kWarpSize) -- confirm at call sites).
//   N              element count, in units of the SRC/DST pointee type.
//   DST, SRC       pointer expressions; each is evaluated once and cached in a local.
//   LD_FUNC        callable taking a source pointer and returning the loaded value; its return
//                  type (minus any reference) is the type of the staging array.
//   ST_FUNC        callable taking (destination pointer, value).
//
// The first loop covers the largest multiple of kWarpSize * UNROLL_FACTOR that fits in N: per
// iteration a lane loads UNROLL_FACTOR elements spaced kWarpSize apart, then stores them.
// The second loop copies the remaining tail one element per step with a stride of kWarpSize.
//
// kWarpSize must be visible where the macro is expanded. N and LANE_ID are expanded several
// times, so pass expressions without side effects.
#define UNROLLED_WARP_COPY_LL(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for(int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
for(int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += kWarpSize) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ #define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \ { \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \ constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
...@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
#ifdef USE_ROCM #ifdef USE_ROCM
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#else #else
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f; constexpr float kFinfoAmaxE4M3 = 448.0f;
...@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { template <bool kRoundScale>
if (round_scale) { __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
if constexpr(kRoundScale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv); scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv); scale_inv = fast_pow2(exp_scale_inv);
......
...@@ -802,26 +802,17 @@ class Buffer: ...@@ -802,26 +802,17 @@ class Buffer:
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch( def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
self, num_max_dispatch_tokens_per_rank: int, num_experts: int,
x: torch.Tensor, use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
topk_idx: torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
num_max_dispatch_tokens_per_rank: int,
num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True,
round_scale: bool = False,
use_ue8m0: bool = False,
async_finish: bool = False,
return_recv_hook: bool = False,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
A low-latency implementation for dispatching with IBGDA. A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled). (specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment. Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments: Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...@@ -830,105 +821,52 @@ class Buffer: ...@@ -830,105 +821,52 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported. are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (syncing would also be incompatible with CUDA graph). as we do not synchronize CPU received count with GPU (syncing would also be incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receives. As mentioned before, not all tokens are valid in `recv_x`. expert receives. As mentioned before, not all tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function. handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set). event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set).
""" """
( packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
packed_recv_x, self.runtime.low_latency_dispatch(x, topk_idx,
packed_recv_x_scales, num_max_dispatch_tokens_per_rank, num_experts,
packed_recv_count, use_fp8, async_finish, return_recv_hook)
packed_recv_src_info, handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
packed_recv_layout_range, tensors_to_record = (x, topk_idx,
event, packed_recv_x, packed_recv_x_scales, packed_recv_count,
hook, packed_recv_src_info, packed_recv_layout_range)
) = self.runtime.low_latency_dispatch( return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
x, EventOverlap(event, tensors_to_record if async_finish else None), hook
topk_idx,
cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank,
num_experts,
use_fp8,
round_scale,
use_ue8m0,
async_finish,
return_recv_hook,
)
handle = (
packed_recv_src_info,
packed_recv_layout_range,
num_max_dispatch_tokens_per_rank,
x.size(1),
num_experts,
)
tensors_to_record = (
x,
topk_idx,
packed_recv_x,
packed_recv_x_scales,
packed_recv_count,
packed_recv_src_info,
packed_recv_layout_range,
cumulative_local_expert_recv_stats,
)
return (
(packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x,
packed_recv_count,
handle,
EventOverlap(event, tensors_to_record if async_finish else None),
hook,
)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine( def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
self, handle: tuple, zero_copy: bool = False, async_finish: bool = False,
x: torch.Tensor, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
topk_idx: torch.Tensor, Tuple[torch.Tensor, EventOverlap, Callable]:
topk_weights: torch.Tensor,
handle: tuple,
use_logfmt: bool = False,
zero_copy: bool = False,
async_finish: bool = False,
return_recv_hook: bool = False,
out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, EventOverlap, Callable]:
""" """
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled). (specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment. Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments: Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...@@ -939,39 +877,23 @@ class Buffer: ...@@ -939,39 +877,23 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you do not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns: Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set). event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set).
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine( combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
x, num_max_dispatch_tokens_per_rank, num_experts,
topk_idx, zero_copy, async_finish, return_recv_hook, out)
topk_weights,
src_info,
layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank,
num_experts,
use_logfmt,
zero_copy,
async_finish,
return_recv_hook,
out,
)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
...@@ -988,34 +910,4 @@ class Buffer: ...@@ -988,34 +910,4 @@ class Buffer:
by yourself. by yourself.
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
return self.runtime.get_next_low_latency_combine_buffer( return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
num_max_dispatch_tokens_per_rank, hidden, num_experts
)
def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)
def low_latency_query_mask_buffer(self, mask_status: torch.Tensor):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self.runtime.low_latency_query_mask_buffer(mask_status)
def low_latency_clean_mask_buffer(self):
"""
Clean the mask buffer
"""
self.runtime.low_latency_clean_mask_buffer()
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#define LIBRARY_INCLUDE_ROCSHMEM_HPP #define LIBRARY_INCLUDE_ROCSHMEM_HPP
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <mpi.h>
#include "rocshmem_config.h" #include "rocshmem_config.h"
#include "rocshmem_common.hpp" #include "rocshmem_common.hpp"
...@@ -36,6 +35,10 @@ ...@@ -36,6 +35,10 @@
#include "rocshmem_COLL.hpp" #include "rocshmem_COLL.hpp"
#include "rocshmem_P2P_SYNC.hpp" #include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_RMA_X.hpp" #include "rocshmem_RMA_X.hpp"
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
/** /**
* @file rocshmem.hpp * @file rocshmem.hpp
* @brief Public header for rocSHMEM device and host libraries. * @brief Public header for rocSHMEM device and host libraries.
...@@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0"; ...@@ -57,13 +60,22 @@ constexpr char VERSION[] = "3.0.0";
/****************************************************************************** /******************************************************************************
**************************** HOST INTERFACE ********************************** **************************** HOST INTERFACE **********************************
*****************************************************************************/ *****************************************************************************/
#if defined(HAVE_EXTERNAL_MPI)
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer. * @brief Initialize the rocSHMEM runtime and underlying transport layer.
* *
* @param[in] comm (Optional) MPI Communicator that rocSHMEM will be using * @param[in] comm MPI Communicator that rocSHMEM will be using
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD * If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
*/ */
__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD); [[deprecated]] __host__ void rocshmem_init(MPI_Comm comm);
#endif
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* This is equivalent to the previous function, implicitly using
* MPI_COMM_WORLD for initialization
*/
__host__ void rocshmem_init(void);
/** /**
* @brief Query rocSHMEM context from host API * @brief Query rocSHMEM context from host API
...@@ -86,8 +98,10 @@ __host__ void * rocshmem_get_device_ctx(); ...@@ -86,8 +98,10 @@ __host__ void * rocshmem_get_device_ctx();
* This can be used to issue load/store from custom kernels * This can be used to issue load/store from custom kernels
* instead of using rocshmem device side get/put APIs for RMA operations. * instead of using rocshmem device side get/put APIs for RMA operations.
*/ */
__host__ void *rocshmem_ptr(void *dest, int pe); __host__ void* rocshmem_ptr(const void *dest, int pe);
__device__ ATTR_NO_INLINE void* rocshmem_ptr(const void *dest, int pe);
#if defined(HAVE_EXTERNAL_MPI)
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer * @brief Initialize the rocSHMEM runtime and underlying transport layer
* with an attempt to enable the requested thread support. * with an attempt to enable the requested thread support.
...@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe); ...@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
* @return int returns 0 upon success; otherwise, it returns a nonzero * @return int returns 0 upon success; otherwise, it returns a nonzero
* value * value
*/ */
__host__ int rocshmem_init_thread(int requested, int *provided, [[deprecated]] __host__ int rocshmem_init_thread(int requested, int *provided,
MPI_Comm comm = MPI_COMM_WORLD); MPI_Comm comm);
#endif
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer * @brief Initialize the rocSHMEM runtime and underlying transport layer
...@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet(); ...@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
*/ */
__host__ void rocshmem_barrier_all(); __host__ void rocshmem_barrier_all();
/**
* @brief enqueues a collective barrier on given stream.
*
* @return void
*/
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
/** /**
* @brief registers the arrival of a PE at a barrier. * @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved. * The caller is blocked until the synchronization is resolved.
...@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status); ...@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_init(); [[deprecated]] __device__ void rocshmem_wg_init();
/** /**
* @brief Finalizes device-side rocSHMEM resources. Must be called before * @brief Finalizes device-side rocSHMEM resources. Must be called before
...@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init(); ...@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_finalize(); [[deprecated]] __device__ void rocshmem_wg_finalize();
/** /**
* @brief Initializes device-side rocSHMEM resources. Must be called before * @brief Initializes device-side rocSHMEM resources. Must be called before
...@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize(); ...@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_init_thread(int requested, int *provided); [[deprecated]] __device__ void rocshmem_wg_init_thread(int requested, int *provided);
/** /**
* @brief Query the thread mode used by the runtime. * @brief Query the thread mode used by the runtime.
...@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx); ...@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_quiet(); __device__ ATTR_NO_INLINE void rocshmem_quiet();
/**
* @brief Completes all previous operations posted to this context for PEs in the
* `target_pes` array.
*
* @param[in] ctx Context with which to perform this operation.
*
* @param[in] target_pes Address of target PE array where the operations need to be completed.
*
* @param[in] npes The number of PEs in the target PE array.
*
* @return void.
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_pe_quiet(rocshmem_ctx_t ctx, const int *target_pes, size_t npes);
__device__ ATTR_NO_INLINE void rocshmem_pe_quiet(const int *target_pes, size_t npes);
/** /**
* @brief Query the total number of PEs. * @brief Query the total number of PEs.
* *
......
...@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce( ...@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source, rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source,
int nreduce); int nreduce);
/**
* @brief kernel for performing a barrier synchronization.
* Caller enqueues the kernel on given stream
*
* @return void
*/
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
/** /**
* @brief perform a collective barrier between all PEs in the system. * @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved. * The caller is blocked until the barrier is resolved.
...@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave( ...@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wg( __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wg(
rocshmem_ctx_t ctx, rocshmem_team_t team); rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief Query a local pointer to a symmetric data object on the
* specified \pe . Returns an address that may be used to directly reference
* dest on the specified \pe. This address can be accessed with LD/ST ops.
*
* Can be called per thread with no performance penalty.
*/
__device__ ATTR_NO_INLINE void *rocshmem_ptr(const void *dest, int pe);
/**
* @brief Make all uncacheable GPU data visible to other agents in the system.
*
* This only works for data that was explicitly allocated uncacheable on the
* GPU!
*
* Can be called per thread with no performance penalty.
*
* @param[in] GPU-side handle.
*
* @return void
*/
} // namespace rocshmem } // namespace rocshmem
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP #endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
...@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8; ...@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
* @brief GPU side OpenSHMEM context created from each work-groups' * @brief GPU side OpenSHMEM context created from each work-groups'
* rocshmem_wg_handle_t * rocshmem_wg_handle_t
*/ */
typedef struct { typedef struct rocshmem_ctx{
void *ctx_opaque; void *ctx_opaque;
void *team_opaque; void *team_opaque;
__host__ __device__ bool operator==(const struct rocshmem_ctx& other) const {
return (ctx_opaque == other.ctx_opaque &&
team_opaque == other.team_opaque);
}
__host__ __device__ bool operator!=(const struct rocshmem_ctx& other) const {
return !(*this == other);
}
} rocshmem_ctx_t; } rocshmem_ctx_t;
/** /**
...@@ -116,6 +125,14 @@ typedef struct { ...@@ -116,6 +125,14 @@ typedef struct {
*/ */
extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT; extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT;
/**
* A value corresponding to an invalid communication context. This value can be
* used to initialize or update context handles to indicate that they do not
* reference a valid context. When managed in this way, applications can use an
* equality comparison to test whether a given context handle references a
* valid context.
*/
extern __constant__ rocshmem_ctx_t ROCSHMEM_CTX_INVALID;
/** /**
* Used internally to set default context. * Used internally to set default context.
*/ */
......
...@@ -45,3 +45,4 @@ ...@@ -45,3 +45,4 @@
/* #undef GDA_IONIC */ /* #undef GDA_IONIC */
/* #undef GDA_BNXT */ /* #undef GDA_BNXT */
#define GDA_MLX5 #define GDA_MLX5
#define HAVE_EXTERNAL_MPI
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
#include <stddef.h>
#include <stdint.h>
#if defined(c_plusplus) || defined(__cplusplus)
extern "C" {
#endif
#if !defined(MPI_VERSION)
// Open MPI based values for the constants/handles etc.
// Even though we did not include an external MPI header file
// The includer may have (e.g., a unit test).
// Opaque handle types: without a real MPI header every handle is carried
// around as an untyped pointer.
typedef void* MPI_Comm;
typedef void* MPI_Group;
typedef void* MPI_Win;
typedef void* MPI_Info;
typedef void* MPI_Request;
typedef void* MPI_Op;
typedef void* MPI_Datatype;

// Status record returned by receive/wait style calls.
// NOTE(review): the field order and types presumably mirror Open MPI's
// public status layout -- confirm against the Open MPI version in use.
typedef struct ompi_status_public_t {
  int MPI_SOURCE;   // rank of the message source
  int MPI_TAG;      // tag of the message
  int MPI_ERROR;    // error code associated with the operation
  int _cancelled;   // implementation-private field
  size_t _ucount;   // implementation-private field
} MPI_Status;
// Integer type used for address differences in the fallback declarations.
#define MPI_Aint uint64_t

// Fallback values for the MPI constants used when no MPI header is visible.
// Every non-trivial expansion is fully parenthesized so that the macro
// behaves like a single primary expression wherever it is used (for example
// `sizeof MPI_IN_PLACE`, which does not parse with an unparenthesized cast).
#define MPI_UNDEFINED (-32766)
#define MPI_THREAD_MULTIPLE 3
#define MPI_SUCCESS 0
#define MPI_IN_PLACE ((void*)1)
#define MPI_MODE_NOCHECK 1
#define MPI_COMM_TYPE_SHARED 0

// Byte distance from addr2 to addr1, converted to MPI_Aint.
#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2)))
// Table of untyped pointers, one per predefined MPI object that the fallback
// macros in this header expose (communicators, request/info/datatype null
// handles, reduction operations and basic datatypes). The macros such as
// MPI_COMM_WORLD read their value from the matching member of `ompi_symbols_`.
//
// NOTE(review): the MPI_WIN_NULL macro in this header reads
// `ompi_symbols_.ompi_mpi_win_null`, but no such member is declared below.
// Adding it changes the struct layout, so it has to be done together with the
// definition of `ompi_symbols_` -- confirm against the library sources.
struct ompi_internal_symbols_t {
// Communicators and null handles.
void *ompi_mpi_comm_world;
void *ompi_mpi_comm_null;
void *ompi_request_null;
void *ompi_mpi_info_null;
void *ompi_mpi_datatype_null;
// Reduction / accumulate operations.
void *ompi_mpi_op_max;
void *ompi_mpi_op_min;
void *ompi_mpi_op_sum;
void *ompi_mpi_op_prod;
void *ompi_mpi_op_band;
void *ompi_mpi_op_bor;
void *ompi_mpi_op_bxor;
void *ompi_mpi_op_replace;
void *ompi_mpi_op_no_op;
// Basic datatypes.
void *ompi_mpi_char;
void *ompi_mpi_unsigned_char;
void *ompi_mpi_signed_char;
void *ompi_mpi_short;
void *ompi_mpi_unsigned_short;
void *ompi_mpi_int;
void *ompi_mpi_unsigned;
void *ompi_mpi_long;
void *ompi_mpi_unsigned_long;
void *ompi_mpi_long_long_int;
void *ompi_mpi_unsigned_long_long;
void *ompi_mpi_float;
void *ompi_mpi_double;
void *ompi_mpi_long_double;
};
// The single table instance; only declared here, the definition lives outside
// this header.
extern struct ompi_internal_symbols_t ompi_symbols_;
// Converts a pointer stored in `ompi_symbols_` to the requested handle type.
// Uses static_cast, so the predefined-handle macros below are C++-only.
#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast<type> (global))

// Communicator, request, window and info handles.
#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world)
#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null)
#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null)
// NOTE(review): `ompi_mpi_win_null` is not a member of
// `struct ompi_internal_symbols_t` as declared in this header, so any use of
// MPI_WIN_NULL fails to compile. Fixing it requires adding the member to the
// struct and to wherever `ompi_symbols_` is defined and filled in.
#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null)
#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null)

// Reduction / accumulate operations.
#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max)
#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min)
#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum)
#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod)
#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band)
#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor)
#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor)
#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace)
#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op)

// Datatypes. Note that MPI_LONG_LONG maps to the `ompi_mpi_long_long_int`
// member.
#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null)
#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char)
#define MPI_UNSIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char)
#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char)
#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short)
#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short)
#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int)
#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned)
#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long)
#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long)
#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int)
#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long)
#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float)
#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double)
#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double)
#endif //!defined(MPI_VERSION)
#if defined(c_plusplus) || defined(__cplusplus)
}
#endif
#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
...@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED) ...@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
set_target_properties(roc::rocshmem PROPERTIES set_target_properties(roc::rocshmem PROPERTIES
INTERFACE_COMPILE_OPTIONS "-fgpu-rdc;-fgpu-rdc" INTERFACE_COMPILE_OPTIONS "-fgpu-rdc;-fgpu-rdc"
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include" INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;Threads::Threads;MPI::MPI_CXX;hip::device;hip::host;hsa-runtime64::hsa-runtime64;-fgpu-rdc" INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;\$<\$<BOOL:ON>:MPI::MPI_CXX>;Threads::Threads;hip::device;hip::host;dl;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
) )
# Load information for each installed configuration. # Load information for each installed configuration.
......
../../../lib/cmake/rocshmem/rocshmem-config-version.cmake # This is a basic version file for the Config-mode of find_package().
\ No newline at end of file # It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
# Version of the installed package that find_package() compares against.
set(PACKAGE_VERSION "3.0.0")
# An installed version older than the requested one is never compatible.
if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()
# Extract the major component of the installed version (digits before the
# first dot). Leading zeros are stripped unless the major version is zero.
if("3.0.0" MATCHES "^([0-9]+)\\.")
set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
endif()
else()
# No dot in the version string: the whole string is used as the major version.
set(CVF_VERSION_MAJOR "3.0.0")
endif()
if(PACKAGE_FIND_VERSION_RANGE)
# both endpoints of the range must have the expected major version
math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
# Incompatible when the lower bound has a different major version, or when the
# upper bound leaves this major version: an inclusive bound must share the
# major version, an exclusive bound must not exceed the next major version.
if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
set(PACKAGE_VERSION_COMPATIBLE FALSE)
# Compatible when the lower bound shares the major version and the installed
# version does not exceed the upper bound (inclusive or exclusive as requested).
elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
else()
# Single-version request: compatible only when the requested major version
# equals the installed one.
if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
# Exact match only when the requested and installed version strings are equal.
if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
set(PACKAGE_VERSION_EXACT TRUE)
endif()
endif()
endif()
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
# (the literal "8" is presumably the installing project's pointer size, baked
# in when this file was generated, which makes the second test constant-false)
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
return()
endif()
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
# Report the installed bitness (8 * 8 = 64) in the version string and reject
# the package for this consumer.
math(EXPR installedBits "8 * 8")
set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
set(PACKAGE_VERSION_UNSUITABLE TRUE)
endif()
../../../lib/cmake/rocshmem/rocshmem-config.cmake
\ No newline at end of file
####################################################################################
# Auto generated @PACKAGE_INIT@ by rocm_configure_package_config_file()
# from rocshmem-config.cmake.in
# Any changes to this file will be overwritten by the next CMake run
####################################################################################
# Derive the installation prefix from the location of this file: resolve the
# file to its real path (REALPATH), take the containing directory, then go
# three directory levels up and make the result absolute.
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_FILE_REAL "${CMAKE_CURRENT_LIST_FILE}" REALPATH)
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_DIR_REAL "${_ROCM_CMAKE_CURRENT_LIST_FILE_REAL}" DIRECTORY)
get_filename_component(PACKAGE_PREFIX_DIR "${_ROCM_CMAKE_CURRENT_LIST_DIR_REAL}/../../../" ABSOLUTE)
# set_and_check(<var> <path>)
#
# Stores <path> in <var> and aborts configuration with a fatal error when the
# path does not exist on disk.
macro(set_and_check _var _file)
  set(${_var} "${_file}")
  # Test the value that was just stored; it is the same path as ${_file}.
  if(NOT EXISTS "${${_var}}")
    message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
  endif()
endmacro()
# Prefer CMake's own find_dependency(); OPTIONAL keeps a missing module from
# being an error and RESULT_VARIABLE records whether it was loaded.
include(CMakeFindDependencyMacro OPTIONAL RESULT_VARIABLE _ROCMCMakeFindDependencyMacro_FOUND)
# Fallback definition, used only when the module above could not be included.
if (NOT _ROCMCMakeFindDependencyMacro_FOUND)
# find_dependency(<dep> [<version>])
# Calls find_package(<dep>) unless <dep>_FOUND is already set, forwarding the
# EXACT / QUIET / REQUIRED settings of the find_package() call that is
# currently loading this config file (named by CMAKE_FIND_PACKAGE_NAME).
macro(find_dependency dep)
if (NOT ${dep}_FOUND)
# Optional second argument: the version to request.
set(rocm_fd_version)
if (${ARGC} GREATER 1)
set(rocm_fd_version ${ARGV1})
endif()
set(rocm_fd_exact_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION_EXACT)
set(rocm_fd_exact_arg EXACT)
endif()
set(rocm_fd_quiet_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
set(rocm_fd_quiet_arg QUIET)
endif()
set(rocm_fd_required_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
set(rocm_fd_required_arg REQUIRED)
endif()
find_package(${dep} ${rocm_fd_version}
${rocm_fd_exact_arg}
${rocm_fd_quiet_arg}
${rocm_fd_required_arg}
)
# Accept either <dep>_FOUND or the upper-cased <DEP>_FOUND as the success flag.
string(TOUPPER ${dep} cmake_dep_upper)
if (NOT ${dep}_FOUND AND NOT ${cmake_dep_upper}_FOUND)
# Mark the package being loaded as not found and explain why. Because this is
# a macro, return() leaves the file that invoked find_dependency(), not just
# the macro body. The rocm_fd_* helper variables are not cleared on this path.
set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE
"${CMAKE_FIND_PACKAGE_NAME} could not be found because dependency ${dep} could not be found.")
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND False)
return()
endif()
# Clear the helper variables again; a macro has no scope of its own, so they
# would otherwise remain set in the caller.
set(rocm_fd_version)
set(rocm_fd_required_arg)
set(rocm_fd_quiet_arg)
set(rocm_fd_exact_arg)
endif()
endmacro()
endif()
# check_required_components(<name>)
#
# Sets <name>_FOUND to FALSE when any component listed in
# <name>_FIND_COMPONENTS is flagged as required (<name>_FIND_REQUIRED_<comp>)
# but was not found (<name>_<comp>_FOUND is false).
macro(check_required_components _NAME)
  foreach(comp ${${_NAME}_FIND_COMPONENTS})
    if(${_NAME}_FIND_REQUIRED_${comp} AND NOT ${_NAME}_${comp}_FOUND)
      set(${_NAME}_FOUND FALSE)
    endif()
  endforeach()
endmacro()
####################################################################################
# Include directory, exported under both the lower-case and upper-case
# variable spellings. Paths are quoted so that a prefix containing `;` is
# passed as a single argument. set_and_check() fails the configure step when
# a path does not exist.
set_and_check(rocshmem_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include")
set_and_check(rocshmem_INCLUDE_DIRS "${PACKAGE_PREFIX_DIR}/include")
set_and_check(ROCSHMEM_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include")
set_and_check(ROCSHMEM_INCLUDE_DIRS "${PACKAGE_PREFIX_DIR}/include")

# Load the exported targets file.
set_and_check(rocshmem_TARGET_FILE "${PACKAGE_PREFIX_DIR}/lib/cmake/rocshmem/rocshmem-targets.cmake")
include("${rocshmem_TARGET_FILE}")

# Library variables, again under both spellings; each names the roc::rocshmem
# target.
set(rocshmem_LIBRARIES roc::rocshmem)
set(rocshmem_LIBRARY roc::rocshmem)
set(ROCSHMEM_LIBRARIES roc::rocshmem)
set(ROCSHMEM_LIBRARY roc::rocshmem)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment