Commit 09cb2b03 authored by lishen's avatar lishen
Browse files

添加low latency接口,正确性需补充

parent 0b14d3b2
...@@ -8,12 +8,17 @@ fi ...@@ -8,12 +8,17 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])") PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])") PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/intranode.cu -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/runtime.cu -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17}
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/deep_ep.cu -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/internode.cu -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17 hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5 hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/intranode.cu -o build_/intranode.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl # build whl
echo "Using Python: $(which python3)" echo "Using Python: $(which python3)"
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once #pragma once
#include "kernels/api.cuh" #include "kernels/api.cuh"
...@@ -105,18 +107,18 @@ struct Config { ...@@ -105,18 +107,18 @@ struct Config {
struct LowLatencyBuffer { struct LowLatencyBuffer {
int num_clean_int = 0; int num_clean_int = 0;
void *dispatch_rdma_send_buffer = nullptr; void* dispatch_rdma_send_buffer = nullptr;
void *dispatch_rdma_recv_data_buffer = nullptr; void* dispatch_rdma_recv_data_buffer = nullptr;
int *dispatch_rdma_recv_count_buffer = nullptr; int64_t* dispatch_rdma_recv_count_buffer = nullptr;
void *combine_rdma_send_buffer = nullptr; void* combine_rdma_send_buffer = nullptr;
void *combine_rdma_recv_data_buffer = nullptr; void* combine_rdma_recv_data_buffer = nullptr;
int *combine_rdma_recv_flag_buffer = nullptr; int64_t* combine_rdma_recv_flag_buffer = nullptr;
void *combine_rdma_send_buffer_data_start = nullptr; void* combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0; size_t num_bytes_per_combine_msg = 0;
std::pair<int *, int> clean_meta() { std::pair<int64_t*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int}; return {dispatch_rdma_recv_count_buffer, num_clean_int};
} }
...@@ -171,29 +173,30 @@ struct LowLatencyLayout { ...@@ -171,29 +173,30 @@ struct LowLatencyLayout {
total_bytes += recv_buffer_bytes * 2; total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers // Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); total_bytes += signaling_buffer_bytes * 2;
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
// Assign pointers // Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer, // NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated // so you may see some parameters are duplicated
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
buffers[i] = { buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)), static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), // dispatch:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + advance(rdma_buffer, send_buffer_bytes * i),
recv_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i), advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), // combine:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + advance(rdma_buffer, send_buffer_bytes * i),
recv_buffer_bytes * i), advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i), advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), // combine_rdma_send_buffer_data_start
num_bytes_per_combine_msg}; advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
//
num_bytes_per_combine_msg
};
} }
} }
}; };
......
This diff is collapsed.
...@@ -30,6 +30,11 @@ private: ...@@ -30,6 +30,11 @@ private:
int64_t num_rdma_bytes; int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr; void *rdma_buffer_ptr = nullptr;
// Shrink mode buffer
bool enable_shrink = false;
int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr;
// Device info and communication // Device info and communication
int device_id; int device_id;
int num_device_sms; int num_device_sms;
...@@ -67,11 +72,9 @@ private: ...@@ -67,11 +72,9 @@ private:
volatile int *moe_recv_rdma_counter = nullptr; volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr; int *moe_recv_rdma_counter_mapped = nullptr;
bool use_default_stream_as_comm_stream = false;
public: public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream); bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);
~Buffer() noexcept(false); ~Buffer() noexcept(false);
...@@ -187,6 +190,12 @@ public: ...@@ -187,6 +190,12 @@ public:
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const; int hidden, int num_experts) const;
void low_latency_update_mask_buffer(int rank_to_mask, bool mask);
void low_latency_query_mask_buffer(const torch::Tensor& mask_status);
void low_latency_clean_mask_buffer();
}; };
} // namespace deep_ep } // namespace deep_ep
...@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights, ...@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode); int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode } // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
int* mask_buffer, int* sync_buffer, hipStream_t stream);
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count,
int* global_atomic_counter,
int* mask_buffer, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* global_atomic_counter,
int* mask_buffer, int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int, int num_combined_tokens,
int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, hipStream_t stream);
void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, hipStream_t stream);
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream);
} // namespace internode_ll
} // namespace deep_ep } // namespace deep_ep
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3 #define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
......
This diff is collapsed.
...@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) { ...@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM); __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
} }
__device__ __forceinline__ void st_release_sys_global(const int64_t *ptr, int64_t val) {
__hip_atomic_store(const_cast<int64_t *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) { __device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_WORKGROUP); __hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_WORKGROUP);
} }
...@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) { ...@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
return ret; return ret;
} }
__device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
int64_t ret;
ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
return ret;
}
__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) { __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
int ret; int ret;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE, // ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
...@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val ...@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
return ret; return ret;
} }
__device__ __forceinline__ int ld_relaxed_global(const int *ptr) {
int ret;
ret = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
return ret;
}
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) { __device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
int ret; int ret;
ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP); ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
...@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) ...@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); __hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
} }
__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
int64_t *non_const_ptr = const_cast<int64_t *>(ptr);
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM // TODO:: apply "st.global.L1::no_allocate" in ROCM
template <typename dtype_t> template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) { __device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
...@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s ...@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens); token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
} }
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
dtype_b_t packed;
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return packed;
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
x = unpacked_ptr[0], y = unpacked_ptr[1];
}
template <typename dtype_t> template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, ""); EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
...@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
__forceinline__ __device__ int warp_reduce_sum(int value) { #ifdef USE_ROCM
if constexpr (kWarpSize == 64) constexpr float kFP8Margin = 1e-4;
value += shfl_xor<int>(value, 32); constexpr float kFinfoAmaxE4M3 = 240.0f;
value += shfl_xor<int>(value, 16); constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
value += shfl_xor<int>(value, 8); #else
value += shfl_xor<int>(value, 4); constexpr float kFP8Margin = 1e-4;
value += shfl_xor<int>(value, 2); constexpr float kFinfoAmaxE4M3 = 448.0f;
value += shfl_xor<int>(value, 1); constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
return value; #endif
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
} }
__forceinline__ __device__ int get_lane_id() { __forceinline__ __device__ int get_lane_id() {
...@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran ...@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
} }
__syncthreads(); __syncthreads();
} }
// Operation functors
template <typename T>
struct ReduceSum {
__device__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct ReduceMax {
__device__ T operator()(T a, T b) const { return a > b ? a : b; }
};
template <typename T>
struct ReduceMin {
__device__ T operator()(T a, T b) const { return a < b ? a : b; }
};
template <typename T>
struct ReduceAnd {
__device__ T operator()(T a, T b) const { return a & b; }
};
template <typename T>
struct ReduceOr {
__device__ T operator()(T a, T b) const { return a | b; }
};
// Unified reduction function
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
EP_STATIC_ASSERT(kNumLanesPerGroup == kWarpSize or kNumLanesPerGroup == 32 or
kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or kNumLanesPerGroup == 4 or
kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1)
value = op(value, shfl_xor(value, 1));
if constexpr (kNumLanesPerGroup <= 2)
value = op(value, shfl_xor(value, 2));
if constexpr (kNumLanesPerGroup <= 4)
value = op(value, shfl_xor(value, 4));
if constexpr (kNumLanesPerGroup <= 8)
value = op(value, shfl_xor(value, 8));
if constexpr (kNumLanesPerGroup <= 16)
value = op(value, shfl_xor(value, 16));
if constexpr(kWarpSize == 64){
if constexpr (kNumLanesPerGroup <= 32)
value = op(value, shfl_xor(value, 32));
}
} else {
if constexpr(kWarpSize == 64){
if constexpr (kNumLanesPerGroup >= kWarpSize)
value = op(value, shfl_xor(value, 32));
}
if constexpr (kNumLanesPerGroup >= 32)
value = op(value, shfl_xor(value, 16));
if constexpr (kNumLanesPerGroup >= 16)
value = op(value, shfl_xor(value, 8));
if constexpr (kNumLanesPerGroup >= 8)
value = op(value, shfl_xor(value, 4));
if constexpr (kNumLanesPerGroup >= 4)
value = op(value, shfl_xor(value, 2));
if constexpr (kNumLanesPerGroup >= 2)
value = op(value, shfl_xor(value, 1));
}
return value;
}
// Convenience aliases
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMax<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMin<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_and(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceAnd<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_or(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
}
} // namespace deep_ep } // namespace deep_ep
...@@ -39,7 +39,7 @@ class Buffer: ...@@ -39,7 +39,7 @@ class Buffer:
allow_nvlink_for_low_latency_mode: bool = True, allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False, allow_mnnvl: bool = False,
explicitly_destroy: bool = False, explicitly_destroy: bool = False,
use_default_stream_as_comm_stream: bool = True, enable_shrink: bool = False,
) -> None: ) -> None:
""" """
Initialize the communication buffer. Initialize the communication buffer.
...@@ -59,6 +59,7 @@ class Buffer: ...@@ -59,6 +59,7 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources; explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor. otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang. Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
""" """
check_nvlink_connections(group) check_nvlink_connections(group)
...@@ -70,6 +71,7 @@ class Buffer: ...@@ -70,6 +71,7 @@ class Buffer:
self.num_rdma_bytes = num_rdma_bytes self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink
self.runtime = deep_ep_cpp.Buffer( self.runtime = deep_ep_cpp.Buffer(
self.rank, self.rank,
self.group_size, self.group_size,
...@@ -77,7 +79,7 @@ class Buffer: ...@@ -77,7 +79,7 @@ class Buffer:
num_rdma_bytes, num_rdma_bytes,
low_latency_mode, low_latency_mode,
explicitly_destroy, explicitly_destroy,
use_default_stream_as_comm_stream, enable_shrink
) )
# Synchronize device IDs # Synchronize device IDs
...@@ -989,3 +991,31 @@ class Buffer: ...@@ -989,3 +991,31 @@ class Buffer:
return self.runtime.get_next_low_latency_combine_buffer( return self.runtime.get_next_low_latency_combine_buffer(
num_max_dispatch_tokens_per_rank, hidden, num_experts num_max_dispatch_tokens_per_rank, hidden, num_experts
) )
def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)
def low_latency_query_mask_buffer(self, mask_status: torch.Tensor):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self.runtime.low_latency_query_mask_buffer(mask_status)
def low_latency_clean_mask_buffer(self):
"""
Clean the mask buffer
"""
self.runtime.low_latency_clean_mask_buffer()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment