Commit da13c63a authored by lishen's avatar lishen
Browse files

完成低延迟接口功能

parent 09cb2b03
pgrep -f /usr/bin/python | xargs kill -9
export OMPI_MCA_pml=ucx export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0 export OMPI_MCA_coll_hcoll_enable=0
...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" ...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32 export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=536870912 export ROCSHMEM_HEAP_SIZE=2880100992
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
pgrep -f /usr/bin/python | xargs kill -9
export OMPI_MCA_pml=ucx export OMPI_MCA_pml=ucx
export OMPI_MCA_osc=ucx export OMPI_MCA_osc=ucx
export OMPI_MCA_coll_hcoll_enable=0 export OMPI_MCA_coll_hcoll_enable=0
...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" ...@@ -7,8 +9,9 @@ export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_MAX_NUM_CONTEXTS=32 export ROCSHMEM_MAX_NUM_CONTEXTS=32
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=536870912 export ROCSHMEM_HEAP_SIZE=2880100992
export PYTHONPATH=/public/home/lishen/Tmp/DeepEP:$PYTHONPATH # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
...@@ -8,8 +8,10 @@ fi ...@@ -8,8 +8,10 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])") PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])") PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}} ROCSHMEM_INSTALL_PREFIX=${ROCSHMEM_INSTALL_PREFIX:=$(pwd)/rocshmem_dir}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17}
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I${ROCSHMEM_INSTALL_PREFIX}/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 -Wno-return-type}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
...@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ...@@ -18,7 +20,7 @@ hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS} hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5 hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L${ROCSHMEM_INSTALL_PREFIX}/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$${ROCSHMEM_INSTALL_PREFIX}/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
......
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts) {
const int num_scales = hidden / 128; const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
......
This diff is collapsed.
...@@ -26,14 +26,17 @@ private: ...@@ -26,14 +26,17 @@ private:
void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void *buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void **buffer_ptrs_gpu = nullptr; void **buffer_ptrs_gpu = nullptr;
void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** nvl_buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer // NVSHMEM Buffer
int64_t num_rdma_bytes; int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr; void *rdma_buffer_ptr = nullptr;
// Shrink mode buffer // Shrink mode buffer
bool enable_shrink = false; bool enable_shrink = false;
int* mask_buffer_ptr = nullptr; int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr; int* sync_buffer_ptr = nullptr;
// Device info and communication // Device info and communication
int device_id; int device_id;
...@@ -171,31 +174,19 @@ public: ...@@ -171,31 +174,19 @@ public:
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts); int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats, bool use_fp8, bool async, bool return_recv_hook);
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor &topk_weights, const torch::Tensor &src_info, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const torch::Tensor &layout_range, int num_max_dispatch_tokens_per_rank, int num_experts,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
void low_latency_update_mask_buffer(int rank_to_mask, bool mask);
void low_latency_query_mask_buffer(const torch::Tensor& mask_status);
void low_latency_clean_mask_buffer(); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
}; };
} // namespace deep_ep } // namespace deep_ep
...@@ -134,42 +134,31 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -134,42 +134,31 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
// Internode low-latency kernels // Internode low-latency kernels
namespace internode_ll { namespace internode_ll {
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1, int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks, hipStream_t stream);
int* mask_buffer, int* sync_buffer, hipStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, void* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* packed_recv_count,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer, int* cumulative_local_expert_recv_stats, void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
int64_t* dispatch_wait_recv_cost_stats, const void* x, const int64_t* topk_idx,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, int64_t* next_clean, int num_next_clean_int,
const void* x, const int64_t* topk_idx, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int64_t* next_clean, int num_next_clean_int, int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, void* workspace, hipStream_t stream, int phases);
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights, const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range, const int* src_info, const int64_t* layout_range,
int* global_atomic_counter, int* global_atomic_counter,
int* mask_buffer, int64_t* combine_wait_recv_cost_stats, int64_t* next_clean, int num_next_clean_int,
int64_t* next_clean, int num_next_clean_int, int num_combined_tokens, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks,
int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, hipStream_t stream,
void* workspace, int num_device_sms, hipStream_t stream, int phases, bool zero_copy);
int phases, bool zero_copy);
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, hipStream_t stream);
void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, hipStream_t stream);
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream);
} // namespace internode_ll } // namespace internode_ll
......
This diff is collapsed.
...@@ -31,6 +31,21 @@ ...@@ -31,6 +31,21 @@
} \ } \
} }
#define UNROLLED_WARP_COPY_LL(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for(int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) unrolled_values[__j] = LD_FUNC(__src + __i + __j * kWarpSize); \
_Pragma("unroll") for(int __j = 0; __j < (UNROLL_FACTOR); ++__j) ST_FUNC(__dst + __i + __j * kWarpSize, unrolled_values[__j]); \
} \
for(int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += kWarpSize) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
#define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ #define UNROLLED_WARP_COPY_EMULATED(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \ { \
constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \ constexpr int kLoopStride = kEmulatedWarpSize * (UNROLL_FACTOR); \
...@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -329,8 +344,8 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
#ifdef USE_ROCM #ifdef USE_ROCM
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#else #else
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f; constexpr float kFinfoAmaxE4M3 = 448.0f;
...@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -350,8 +365,9 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { template <bool kRoundScale>
if (round_scale) { __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
if constexpr(kRoundScale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv); scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv); scale_inv = fast_pow2(exp_scale_inv);
......
...@@ -802,26 +802,17 @@ class Buffer: ...@@ -802,26 +802,17 @@ class Buffer:
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch( def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
self, num_max_dispatch_tokens_per_rank: int, num_experts: int,
x: torch.Tensor, use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
topk_idx: torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
num_max_dispatch_tokens_per_rank: int,
num_experts: int,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True,
round_scale: bool = False,
use_ue8m0: bool = False,
async_finish: bool = False,
return_recv_hook: bool = False,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
A low-latency implementation for dispatching with IBGDA. A low-latency implementation for dispatching with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled). (specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment. Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments: Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
...@@ -830,105 +821,52 @@ class Buffer: ...@@ -830,105 +821,52 @@ class Buffer:
are supported. `-1` indices (not selecting any expert) are supported. are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
`[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
monitoring.
dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you not set this flag, the kernel will ensure the data's arrival.
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receives. As mentioned before, not all tokens are valid in `recv_x`. expert receive. As mentioned before, all not tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function. handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set). event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set).
""" """
( packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
packed_recv_x, self.runtime.low_latency_dispatch(x, topk_idx,
packed_recv_x_scales, num_max_dispatch_tokens_per_rank, num_experts,
packed_recv_count, use_fp8, async_finish, return_recv_hook)
packed_recv_src_info, handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
packed_recv_layout_range, tensors_to_record = (x, topk_idx,
event, packed_recv_x, packed_recv_x_scales, packed_recv_count,
hook, packed_recv_src_info, packed_recv_layout_range)
) = self.runtime.low_latency_dispatch( return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
x, EventOverlap(event, tensors_to_record if async_finish else None), hook
topk_idx,
cumulative_local_expert_recv_stats,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank,
num_experts,
use_fp8,
round_scale,
use_ue8m0,
async_finish,
return_recv_hook,
)
handle = (
packed_recv_src_info,
packed_recv_layout_range,
num_max_dispatch_tokens_per_rank,
x.size(1),
num_experts,
)
tensors_to_record = (
x,
topk_idx,
packed_recv_x,
packed_recv_x_scales,
packed_recv_count,
packed_recv_src_info,
packed_recv_layout_range,
cumulative_local_expert_recv_stats,
)
return (
(packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x,
packed_recv_count,
handle,
EventOverlap(event, tensors_to_record if async_finish else None),
hook,
)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine( def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
self, handle: tuple, zero_copy: bool = False, async_finish: bool = False,
x: torch.Tensor, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
topk_idx: torch.Tensor, Tuple[torch.Tensor, EventOverlap, Callable]:
topk_weights: torch.Tensor,
handle: tuple,
use_logfmt: bool = False,
zero_copy: bool = False,
async_finish: bool = False,
return_recv_hook: bool = False,
out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, EventOverlap, Callable]:
""" """
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled). (specifically, IBGDA must be enabled).
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 Even for ranks in the same node, NVLink are fully disabled for simplicity.
low-latency kernels' result tensors at a single moment. Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments: Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
...@@ -939,39 +877,23 @@ class Buffer: ...@@ -939,39 +877,23 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you do not set this flag, the kernel will ensure the data's arrival. If you not set this flag, the kernel will ensure the data's arrival.
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Returns: Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set). event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set).
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine( combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
x, num_max_dispatch_tokens_per_rank, num_experts,
topk_idx, zero_copy, async_finish, return_recv_hook, out)
topk_weights,
src_info,
layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank,
num_experts,
use_logfmt,
zero_copy,
async_finish,
return_recv_hook,
out,
)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
...@@ -988,34 +910,4 @@ class Buffer: ...@@ -988,34 +910,4 @@ class Buffer:
by yourself. by yourself.
""" """
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
return self.runtime.get_next_low_latency_combine_buffer( return self.runtime.get_next_low_latency_combine_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
num_max_dispatch_tokens_per_rank, hidden, num_experts
)
def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)
def low_latency_query_mask_buffer(self, mask_status: torch.Tensor):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self.runtime.low_latency_query_mask_buffer(mask_status)
def low_latency_clean_mask_buffer(self):
"""
Clean the mask buffer
"""
self.runtime.low_latency_clean_mask_buffer()
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#define LIBRARY_INCLUDE_ROCSHMEM_HPP #define LIBRARY_INCLUDE_ROCSHMEM_HPP
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <mpi.h>
#include "rocshmem_config.h" #include "rocshmem_config.h"
#include "rocshmem_common.hpp" #include "rocshmem_common.hpp"
...@@ -36,6 +35,10 @@ ...@@ -36,6 +35,10 @@
#include "rocshmem_COLL.hpp" #include "rocshmem_COLL.hpp"
#include "rocshmem_P2P_SYNC.hpp" #include "rocshmem_P2P_SYNC.hpp"
#include "rocshmem_RMA_X.hpp" #include "rocshmem_RMA_X.hpp"
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
/** /**
* @file rocshmem.hpp * @file rocshmem.hpp
* @brief Public header for rocSHMEM device and host libraries. * @brief Public header for rocSHMEM device and host libraries.
...@@ -57,20 +60,29 @@ constexpr char VERSION[] = "3.0.0"; ...@@ -57,20 +60,29 @@ constexpr char VERSION[] = "3.0.0";
/****************************************************************************** /******************************************************************************
**************************** HOST INTERFACE ********************************** **************************** HOST INTERFACE **********************************
*****************************************************************************/ *****************************************************************************/
#if defined(HAVE_EXTERNAL_MPI)
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer. * @brief Initialize the rocSHMEM runtime and underlying transport layer.
* *
* @param[in] comm (Optional) MPI Communicator that rocSHMEM will be using * @param[in] comm MPI Communicator that rocSHMEM will be using
* If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD * If MPI_COMM_NULL, rocSHMEM will be using MPI_COMM_WORLD
*/ */
__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD); [[deprecated]] __host__ void rocshmem_init(MPI_Comm comm);
#endif
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer.
* This is equivalent to the previous function, using implicitely
* MPI_COMM_WORLD for initialization
*/
__host__ void rocshmem_init(void);
/** /**
* @brief Query rocSHMEM context from host API * @brief Query rocSHMEM context from host API
* *
* @param[out] ctx Returns ROCSHMEM_CTX_DEFAULT device pointer that users * @param[out] ctx Returns ROCSHMEM_CTX_DEFAULT device pointer that users
* can query from one instance of rocshmem host library and * can query from one instance of rocshmem host library and
* use use later for dynamic module initialization in * use use later for dynamic module initialization in
* kernel bitcode device library in the same application * kernel bitcode device library in the same application
*/ */
__host__ void * rocshmem_get_device_ctx(); __host__ void * rocshmem_get_device_ctx();
...@@ -79,15 +91,17 @@ __host__ void * rocshmem_get_device_ctx(); ...@@ -79,15 +91,17 @@ __host__ void * rocshmem_get_device_ctx();
* @brief Query rocSHMEM remote symmetric heap pointer * @brief Query rocSHMEM remote symmetric heap pointer
* *
* @param[in] dest local symmetric heap allocation pointer for current pe/device * @param[in] dest local symmetric heap allocation pointer for current pe/device
* *
* @param[in] pe remote PE * @param[in] pe remote PE
* *
* @param[out] ptr Returns remote symmetric heap device pointer from host-side API. * @param[out] ptr Returns remote symmetric heap device pointer from host-side API.
* This can be used to issue load/store from custom kernels * This can be used to issue load/store from custom kernels
* instead of using rocshmem device side get/put APIs for RMA operations. * instead of using rocshmem device side get/put APIs for RMA operations.
*/ */
__host__ void *rocshmem_ptr(void *dest, int pe); __host__ void* rocshmem_ptr(const void *dest, int pe);
__device__ ATTR_NO_INLINE void* rocshmem_ptr(const void *dest, int pe);
#if defined(HAVE_EXTERNAL_MPI)
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer * @brief Initialize the rocSHMEM runtime and underlying transport layer
* with an attempt to enable the requested thread support. * with an attempt to enable the requested thread support.
...@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe); ...@@ -102,8 +116,9 @@ __host__ void *rocshmem_ptr(void *dest, int pe);
* @return int returns 0 upon success; otherwise, it returns a nonzero * @return int returns 0 upon success; otherwise, it returns a nonzero
* value * value
*/ */
__host__ int rocshmem_init_thread(int requested, int *provided, [[deprecated]] __host__ int rocshmem_init_thread(int requested, int *provided,
MPI_Comm comm = MPI_COMM_WORLD); MPI_Comm comm);
#endif
/** /**
* @brief Initialize the rocSHMEM runtime and underlying transport layer * @brief Initialize the rocSHMEM runtime and underlying transport layer
...@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet(); ...@@ -327,6 +342,13 @@ __host__ void rocshmem_quiet();
*/ */
__host__ void rocshmem_barrier_all(); __host__ void rocshmem_barrier_all();
/**
* @brief enqueues a collective barrier on given stream.
*
* @return void
*/
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
/** /**
* @brief registers the arrival of a PE at a barrier. * @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved. * The caller is blocked until the synchronization is resolved.
...@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status); ...@@ -360,7 +382,7 @@ __host__ void rocshmem_global_exit(int status);
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_init(); [[deprecated]] __device__ void rocshmem_wg_init();
/** /**
* @brief Finalizes device-side rocSHMEM resources. Must be called before * @brief Finalizes device-side rocSHMEM resources. Must be called before
...@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init(); ...@@ -370,7 +392,7 @@ __device__ void rocshmem_wg_init();
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_finalize(); [[deprecated]] __device__ void rocshmem_wg_finalize();
/** /**
* @brief Initializes device-side rocSHMEM resources. Must be called before * @brief Initializes device-side rocSHMEM resources. Must be called before
...@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize(); ...@@ -386,7 +408,7 @@ __device__ void rocshmem_wg_finalize();
* *
* @return void. * @return void.
*/ */
__device__ void rocshmem_wg_init_thread(int requested, int *provided); [[deprecated]] __device__ void rocshmem_wg_init_thread(int requested, int *provided);
/** /**
* @brief Query the thread mode used by the runtime. * @brief Query the thread mode used by the runtime.
...@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx); ...@@ -476,6 +498,23 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_quiet(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_quiet(); __device__ ATTR_NO_INLINE void rocshmem_quiet();
/**
* @brief Completes all previous operations posted to this context for PEs in the
* `target_pes` array.
*
* @param[in] ctx Context with which to perform this operation.
*
* @param[in] target_pes Address of target PE array where the operations need to be completed.
*
* @param[in] npes The number of PEs in the target PE array.
*
* @return void.
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_pe_quiet(rocshmem_ctx_t ctx, const int *target_pes, size_t npes);
__device__ ATTR_NO_INLINE void rocshmem_pe_quiet(const int *target_pes, size_t npes);
/** /**
* @brief Query the total number of PEs. * @brief Query the total number of PEs.
* *
......
...@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce( ...@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source, rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source,
int nreduce); int nreduce);
/**
* @brief kernel for performing a barrier synchronization.
* Caller enqueues the kernel on given stream
*
* @return void
*/
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
/** /**
* @brief perform a collective barrier between all PEs in the system. * @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved. * The caller is blocked until the barrier is resolved.
...@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave( ...@@ -767,28 +775,6 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wave(
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wg( __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_wg(
rocshmem_ctx_t ctx, rocshmem_team_t team); rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief Query a local pointer to a symmetric data object on the
* specified \pe . Returns an address that may be used to directly reference
* dest on the specified \pe. This address can be accesses with LD/ST ops.
*
* Can be called per thread with no performance penalty.
*/
__device__ ATTR_NO_INLINE void *rocshmem_ptr(const void *dest, int pe);
/**
* @brief Make all uncacheable GPU data visible to other agents in the sytem.
*
* This only works for data that was explicitly allocated uncacheable on the
* GPU!
*
* Can be called per thread with no performance penalty.
*
* @param[in] GPU-side handle.
*
* @return void
*/
} // namespace rocshmem } // namespace rocshmem
#endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP #endif // LIBRARY_INCLUDE_ROCSHMEM_COLL_HPP
...@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8; ...@@ -106,9 +106,18 @@ const int ROCSHMEM_CTX_SHARED = 8;
* @brief GPU side OpenSHMEM context created from each work-groups' * @brief GPU side OpenSHMEM context created from each work-groups'
* rocshmem_wg_handle_t * rocshmem_wg_handle_t
*/ */
typedef struct { typedef struct rocshmem_ctx{
void *ctx_opaque; void *ctx_opaque;
void *team_opaque; void *team_opaque;
__host__ __device__ bool operator==(const struct rocshmem_ctx& other) const {
return (ctx_opaque == other.ctx_opaque &&
team_opaque == other.team_opaque);
}
__host__ __device__ bool operator!=(const struct rocshmem_ctx& other) const {
return !(*this == other);
}
} rocshmem_ctx_t; } rocshmem_ctx_t;
/** /**
...@@ -116,6 +125,14 @@ typedef struct { ...@@ -116,6 +125,14 @@ typedef struct {
*/ */
extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT; extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT;
/**
* A value corresponding to an invalid communication context. This value can be
* used to initialize or update context handles to indicate that they do not
* reference a valid context. When managed in this way, applications can use an
* equality comparison to test whether a given context handle references a
* valid context.
*/
extern __constant__ rocshmem_ctx_t ROCSHMEM_CTX_INVALID;
/** /**
* Used internally to set default context. * Used internally to set default context.
*/ */
......
...@@ -45,3 +45,4 @@ ...@@ -45,3 +45,4 @@
/* #undef GDA_IONIC */ /* #undef GDA_IONIC */
/* #undef GDA_BNXT */ /* #undef GDA_BNXT */
#define GDA_MLX5 #define GDA_MLX5
#define HAVE_EXTERNAL_MPI
/******************************************************************************
* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#define LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
#if defined(HAVE_EXTERNAL_MPI)
#include <mpi.h>
#endif
#if defined(c_plusplus) || defined(__cplusplus)
extern "C" {
#endif
#if !defined(MPI_VERSION)
// Open MPI based values for the constants/handles etc.
// Even though we did not include an external MPI header file
// The includer may have (e.g., a unit test).
typedef void* MPI_Comm;
typedef void* MPI_Win;
typedef void* MPI_Group;
typedef void* MPI_Op;
typedef void* MPI_Datatype;
typedef void* MPI_Request;
typedef void* MPI_Info;
struct ompi_status_public_t {
int MPI_SOURCE;
int MPI_TAG;
int MPI_ERROR;
int _cancelled;
size_t _ucount;
};
typedef struct ompi_status_public_t MPI_Status;
#define MPI_Aint uint64_t
#define MPI_UNDEFINED -32766
#define MPI_THREAD_MULTIPLE 3
#define MPI_SUCCESS 0
#define MPI_IN_PLACE (void*)1
#define MPI_MODE_NOCHECK 1
#define MPI_COMM_TYPE_SHARED 0
#define MPI_Aint_diff(addr1, addr2) ((MPI_Aint) ((char *) (addr1) - (char *) (addr2)))
struct ompi_internal_symbols_t {
void *ompi_mpi_comm_world;
void *ompi_mpi_comm_null;
void *ompi_request_null;
void *ompi_mpi_info_null;
void *ompi_mpi_datatype_null;
void *ompi_mpi_op_max;
void *ompi_mpi_op_min;
void *ompi_mpi_op_sum;
void *ompi_mpi_op_prod;
void *ompi_mpi_op_band;
void *ompi_mpi_op_bor;
void *ompi_mpi_op_bxor;
void *ompi_mpi_op_replace;
void *ompi_mpi_op_no_op;
void *ompi_mpi_char;
void *ompi_mpi_unsigned_char;
void *ompi_mpi_signed_char;
void *ompi_mpi_short;
void *ompi_mpi_unsigned_short;
void *ompi_mpi_int;
void *ompi_mpi_unsigned;
void *ompi_mpi_long;
void *ompi_mpi_unsigned_long;
void *ompi_mpi_long_long_int;
void *ompi_mpi_unsigned_long_long;
void *ompi_mpi_float;
void *ompi_mpi_double;
void *ompi_mpi_long_double;
};
extern struct ompi_internal_symbols_t ompi_symbols_;
#define OMPI_PREDEFINED_GLOBAL(type, global) (static_cast<type> (global))
#define MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_world)
#define MPI_COMM_NULL OMPI_PREDEFINED_GLOBAL(MPI_Comm, ompi_symbols_.ompi_mpi_comm_null)
#define MPI_REQUEST_NULL OMPI_PREDEFINED_GLOBAL(MPI_Request, ompi_symbols_.ompi_request_null)
#define MPI_WIN_NULL OMPI_PREDEFINED_GLOBAL(MPI_Win, ompi_symbols_.ompi_mpi_win_null)
#define MPI_INFO_NULL OMPI_PREDEFINED_GLOBAL(MPI_Info, ompi_symbols_.ompi_mpi_info_null)
#define MPI_MAX OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_max)
#define MPI_MIN OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_min)
#define MPI_SUM OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_sum)
#define MPI_PROD OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_prod)
#define MPI_BAND OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_band)
#define MPI_BOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bor)
#define MPI_BXOR OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_bxor)
#define MPI_REPLACE OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_replace)
#define MPI_NO_OP OMPI_PREDEFINED_GLOBAL(MPI_Op, ompi_symbols_.ompi_mpi_op_no_op)
#define MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_datatype_null)
#define MPI_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_char)
#define MPI_UNSIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_char)
#define MPI_SIGNED_CHAR OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_signed_char)
#define MPI_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_short)
#define MPI_UNSIGNED_SHORT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_short)
#define MPI_INT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_int)
#define MPI_UNSIGNED OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned)
#define MPI_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long)
#define MPI_UNSIGNED_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long)
#define MPI_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_long_int)
#define MPI_UNSIGNED_LONG_LONG OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_unsigned_long_long)
#define MPI_FLOAT OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_float)
#define MPI_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_double)
#define MPI_LONG_DOUBLE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, ompi_symbols_.ompi_mpi_long_double)
#endif //!defined(MPI_VERSION)
#if defined(c_plusplus) || defined(__cplusplus)
}
#endif
#endif //LIBRARY_INCLUDE_ROCSHMEM_MPI_HPP
...@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED) ...@@ -61,7 +61,7 @@ add_library(roc::rocshmem STATIC IMPORTED)
set_target_properties(roc::rocshmem PROPERTIES set_target_properties(roc::rocshmem PROPERTIES
INTERFACE_COMPILE_OPTIONS "-fgpu-rdc;-fgpu-rdc" INTERFACE_COMPILE_OPTIONS "-fgpu-rdc;-fgpu-rdc"
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include" INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;Threads::Threads;MPI::MPI_CXX;hip::device;hip::host;hsa-runtime64::hsa-runtime64;-fgpu-rdc" INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;\$<\$<BOOL:ON>:MPI::MPI_CXX>;Threads::Threads;hip::device;hip::host;dl;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
) )
# Load information for each installed configuration. # Load information for each installed configuration.
......
../../../lib/cmake/rocshmem/rocshmem-config-version.cmake # This is a basic version file for the Config-mode of find_package().
\ No newline at end of file # It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
set(PACKAGE_VERSION "3.0.0")
if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()
if("3.0.0" MATCHES "^([0-9]+)\\.")
set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
endif()
else()
set(CVF_VERSION_MAJOR "3.0.0")
endif()
if(PACKAGE_FIND_VERSION_RANGE)
# both endpoints of the range must have the expected major version
math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
set(PACKAGE_VERSION_COMPATIBLE FALSE)
elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
else()
if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
set(PACKAGE_VERSION_COMPATIBLE TRUE)
else()
set(PACKAGE_VERSION_COMPATIBLE FALSE)
endif()
if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
set(PACKAGE_VERSION_EXACT TRUE)
endif()
endif()
endif()
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
return()
endif()
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
math(EXPR installedBits "8 * 8")
set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
set(PACKAGE_VERSION_UNSUITABLE TRUE)
endif()
../../../lib/cmake/rocshmem/rocshmem-config.cmake
\ No newline at end of file
####################################################################################
# Auto generated @PACKAGE_INIT@ by rocm_configure_package_config_file()
# from rocshmem-config.cmake.in
# Any changes to this file will be overwritten by the next CMake run
####################################################################################
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_FILE_REAL "${CMAKE_CURRENT_LIST_FILE}" REALPATH)
get_filename_component(_ROCM_CMAKE_CURRENT_LIST_DIR_REAL "${_ROCM_CMAKE_CURRENT_LIST_FILE_REAL}" DIRECTORY)
get_filename_component(PACKAGE_PREFIX_DIR "${_ROCM_CMAKE_CURRENT_LIST_DIR_REAL}/../../../" ABSOLUTE)
macro(set_and_check _var _file)
set(${_var} "${_file}")
if(NOT EXISTS "${_file}")
message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
endif()
endmacro()
include(CMakeFindDependencyMacro OPTIONAL RESULT_VARIABLE _ROCMCMakeFindDependencyMacro_FOUND)
if (NOT _ROCMCMakeFindDependencyMacro_FOUND)
macro(find_dependency dep)
if (NOT ${dep}_FOUND)
set(rocm_fd_version)
if (${ARGC} GREATER 1)
set(rocm_fd_version ${ARGV1})
endif()
set(rocm_fd_exact_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION_EXACT)
set(rocm_fd_exact_arg EXACT)
endif()
set(rocm_fd_quiet_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
set(rocm_fd_quiet_arg QUIET)
endif()
set(rocm_fd_required_arg)
if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
set(rocm_fd_required_arg REQUIRED)
endif()
find_package(${dep} ${rocm_fd_version}
${rocm_fd_exact_arg}
${rocm_fd_quiet_arg}
${rocm_fd_required_arg}
)
string(TOUPPER ${dep} cmake_dep_upper)
if (NOT ${dep}_FOUND AND NOT ${cmake_dep_upper}_FOUND)
set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE
"${CMAKE_FIND_PACKAGE_NAME} could not be found because dependency ${dep} could not be found.")
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND False)
return()
endif()
set(rocm_fd_version)
set(rocm_fd_required_arg)
set(rocm_fd_quiet_arg)
set(rocm_fd_exact_arg)
endif()
endmacro()
endif()
macro(check_required_components _NAME)
foreach(comp ${${_NAME}_FIND_COMPONENTS})
if(NOT ${_NAME}_${comp}_FOUND)
if(${_NAME}_FIND_REQUIRED_${comp})
set(${_NAME}_FOUND FALSE)
endif()
endif()
endforeach()
endmacro()
####################################################################################
set_and_check(rocshmem_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(ROCSHMEM_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(ROCSHMEM_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIR ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set_and_check(rocshmem_TARGET_FILE ${PACKAGE_PREFIX_DIR}/lib/cmake/rocshmem/rocshmem-targets.cmake)
include(${rocshmem_TARGET_FILE})
set(rocshmem_LIBRARIES roc::rocshmem)
set(rocshmem_LIBRARY roc::rocshmem)
set(ROCSHMEM_LIBRARIES roc::rocshmem)
set(ROCSHMEM_LIBRARY roc::rocshmem)
set(rocshmem_LIBRARIES roc::rocshmem)
set(rocshmem_LIBRARY roc::rocshmem)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment