"csrc/vscode:/vscode.git/clone" did not exist on "7efb944d8e7e640e11f7356aad7120abda4529c7"
Commit 09cb2b03 authored by lishen's avatar lishen
Browse files

添加low latency接口,正确性需补充

parent 0b14d3b2
......@@ -8,12 +8,17 @@ fi
PYTHON_INCLUDE=$(python3 -c "from sysconfig import get_paths; print(get_paths()['include'])")
PYTHON_PLATLIB=$(python3 -c "from sysconfig import get_paths; print(get_paths()['platlib'])")
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/intranode.cu -o build_/intranode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/runtime.cu -o build_/runtime.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/layout.cu -o build_/layout.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/deep_ep.cu -o build_/deep_ep.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
/opt/dtk/bin/hipcc -Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE} -c -c ./csrc/kernels/internode.cu -o build_/internode.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
INCLUDE_PATHS=${INCLUDE_PATHS:=-Icsrc/ -I$(pwd)/rocshmem_dir/include/ -I/opt/mpi/include -I${PYTHON_PLATLIB}/torch/include -I${PYTHON_PLATLIB}/torch/include/torch/csrc/api/include -I${PYTHON_PLATLIB}/torch/include/TH -I${PYTHON_PLATLIB}/torch/include/THC -I${PYTHON_PLATLIB}/torch/include/THH -I/opt/dtk/include -I${PYTHON_INCLUDE}}
COMPILE_OPTIONS=${COMPILE_OPTIONS:= -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -fgpu-rdc -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=deep_ep_cpp -D_GLIBCXX_USE_CXX11_ABI=1 --offload-arch=gfx936 -std=c++17}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/runtime.cu -o build_/runtime.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/layout.cu -o build_/layout.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/intranode.cu -o build_/intranode.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode.cu -o build_/internode.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/kernels/internode_ll.cu -o build_/internode_ll.o ${COMPILE_OPTIONS}
hipcc ${INCLUDE_PATHS} -c $(pwd)/csrc/deep_ep.cu -o build_/deep_ep.o ${COMPILE_OPTIONS}
hipcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions build_/internode.o build_/intranode.o build_/runtime.o build_/deep_ep.o build_/layout.o build_/internode_ll.o -L$(pwd)/rocshmem_dir/lib/ -L/opt/mpi/lib -L/opt/dtk/hip/lib -L/usr/lib/x86_64-linux-gnu -lhipblaslt -lamdhip64 -o deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,/opt/dtk/lib -fgpu-rdc --hip-link --offload-arch=gfx936 -shared -Wl,-soname,deep_ep/deep_ep_cpp.cpython-310-x86_64-linux-gnu.so -Wl,-rpath,$(pwd)/rocshmem_dir/lib/ -L"/opt/dtk/llvm/lib/clang/15.0.0/include/../lib/linux" -lclang_rt.builtins-x86_64 /opt/dtk/hip/lib/libgalaxyhip.so /opt/dtk/llvm/lib/clang/15.0.0/lib/linux/libclang_rt.builtins-x86_64.a /opt/hyhal/lib/libhsa-runtime64.so.1.11.0 -L${PYTHON_PLATLIB}/torch/lib -L/opt/dtk/lib -L/opt/dtk/hip/lib -L/usr/local/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -lrocm-core -lrocm_smi64 -l:librocshmem.a -fgpu-rdc --hip-link -lamdhip64 -lhsa-runtime64 -l:libmpi.so -Wl,-rpath,/opt/mpi/lib/ -libverbs -lmlx5
# build whl
echo "Using Python: $(which python3)"
......
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#include "kernels/api.cuh"
......@@ -105,18 +107,18 @@ struct Config {
struct LowLatencyBuffer {
int num_clean_int = 0;
void *dispatch_rdma_send_buffer = nullptr;
void *dispatch_rdma_recv_data_buffer = nullptr;
int *dispatch_rdma_recv_count_buffer = nullptr;
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int64_t* dispatch_rdma_recv_count_buffer = nullptr;
void *combine_rdma_send_buffer = nullptr;
void *combine_rdma_recv_data_buffer = nullptr;
int *combine_rdma_recv_flag_buffer = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int64_t* combine_rdma_recv_flag_buffer = nullptr;
void *combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
void* combine_rdma_send_buffer_data_start = nullptr;
size_t num_bytes_per_combine_msg = 0;
std::pair<int *, int> clean_meta() {
std::pair<int64_t*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
......@@ -171,29 +173,30 @@ struct LowLatencyLayout {
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes =
std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 +
recv_buffer_bytes * i),
advance<int *>(rdma_buffer, signaling_buffer_bytes_aligned * i),
advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
num_bytes_per_combine_msg};
static_cast<int>(signaling_buffer_bytes / sizeof(int64_t)),
// dispatch:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
// combine:send_buffer + recv_buffer + recv_count
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int64_t*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
// combine_rdma_send_buffer_data_start
advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
//
num_bytes_per_combine_msg
};
}
}
};
......
// #include <ATen/dtk_macros.h>
#include <ATen/dtk_macros.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/HIPDataType.h>
#include <chrono>
......@@ -13,20 +13,19 @@
namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy,
bool use_default_stream_as_comm_stream)
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink)
: rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy),
use_default_stream_as_comm_stream(use_default_stream_as_comm_stream),
comm_stream(use_default_stream_as_comm_stream
? at::hip::getCurrentHIPStreamMasqueradingAsCUDA()
: at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
enable_shrink(enable_shrink),
comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void *);
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);
EP_HOST_ASSERT(enable_shrink == false);
// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
(num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
......@@ -77,7 +76,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
}
// Create 32 MiB workspace
CUDA_CHECK(hipMalloc(&workspace, NUM_WORKSPACE_BYTES));
CUDA_CHECK(hipExtMallocWithFlags(&workspace, NUM_WORKSPACE_BYTES, hipDeviceMallocUncached));
CUDA_CHECK(hipMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));
// MoE counter
......@@ -200,6 +199,10 @@ void Buffer::destroy() {
CUDA_CHECK(hipDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
if (enable_shrink) {
internode::free(mask_buffer_ptr);
internode::free(sync_buffer_ptr);
}
internode::finalize();
}
#endif
......@@ -253,25 +256,32 @@ void Buffer::sync(const std::vector<int> &device_
// Sync ROCSHMEM handles and allocate memory
if (num_rdma_bytes > 0) {
// Initialize NVSHMEM
// Initialize ROCSHMEM
EP_HOST_ASSERT(root_unique_id_opt.has_value());
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
EP_HOST_ASSERT(nvshmem_rank ==
internode::init(
root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
internode::barrier();
// Allocate
rdma_buffer_ptr =
internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK(hipMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
// Allocate and clean shrink buffer
if (enable_shrink) {
int num_mask_buffer_bytes = num_ranks * sizeof(int);
int num_sync_buffer_bytes = num_ranks * sizeof(int);
mask_buffer_ptr = reinterpret_cast<int*>(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES));
sync_buffer_ptr = reinterpret_cast<int*>(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES));
CUDA_CHECK(hipMemset(mask_buffer_ptr, 0, num_mask_buffer_bytes));
CUDA_CHECK(hipMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes));
}
// Barrier
internode::barrier();
CUDA_CHECK(hipDeviceSynchronize());
......@@ -298,13 +308,11 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
at::hip::setCurrentHIPStreamMasqueradingAsCUDA(comm_stream);
}
if (not use_default_stream_as_comm_stream) {
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Wait previous tasks to be finished
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
auto num_tokens = static_cast<int>(topk_idx.size(0)),
......@@ -342,9 +350,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts,
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
if (not use_default_stream_as_comm_stream) {
stream_wait(compute_stream, comm_stream);
}
stream_wait(compute_stream, comm_stream);
}
// Switch back compute stream
......@@ -461,12 +467,10 @@ Buffer::intranode_dispatch(
}
// Wait previous tasks to be finished
if (not use_default_stream_as_comm_stream) {
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only return for non-cached mode)
......@@ -623,9 +627,7 @@ Buffer::intranode_dispatch(
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
if (not use_default_stream_as_comm_stream) {
stream_wait(compute_stream, comm_stream);
}
stream_wait(compute_stream, comm_stream);
}
// Switch back compute stream
......@@ -691,12 +693,10 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
}
// Wait previous tasks to be finished
if (not use_default_stream_as_comm_stream) {
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
int num_topk = 0;
......@@ -765,9 +765,7 @@ Buffer::intranode_combine(const torch::Tensor &x, const std::optional<torch::Ten
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
if (not use_default_stream_as_comm_stream) {
stream_wait(compute_stream, comm_stream);
}
stream_wait(compute_stream, comm_stream);
}
// Switch back compute stream
......@@ -804,8 +802,8 @@ Buffer::internode_dispatch(const torch::Tensor &x, const std::optional<torch::Te
// here.
pybind11::gil_scoped_release release;
const int num_channels = config.num_sms / 3;
EP_HOST_ASSERT(config.num_sms % 3 == 0);
const int num_channels = config.num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
......@@ -1125,8 +1123,8 @@ Buffer::internode_combine(
const torch::Tensor &combined_nvl_head, const Config &config,
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream) {
#ifndef DISABLE_ROCSHMEM
const int num_channels = config.num_sms / 3;
EP_HOST_ASSERT(config.num_sms % 3 == 0);
const int num_channels = config.num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
EP_HOST_ASSERT(config.num_sms % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0);
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
......@@ -1272,39 +1270,329 @@ Buffer::internode_combine(
#endif
}
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts) {
EP_HOST_ASSERT(false and "not support low latency");
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta();
auto check_boundary = [=](void* ptr, size_t num_bytes) {
auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);
};
check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int64_t));
check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int64_t));
internode_ll::clean_low_latency_buffer(clean_meta_0.first,
clean_meta_0.second,
clean_meta_1.first,
clean_meta_1.second,
rank,
num_ranks,
mask_buffer_ptr,
sync_buffer_ptr,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
#endif
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor,
std::optional<EventHandle>, std::optional<std::function<void()>>>
std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor &x, const torch::Tensor &topk_idx,
const std::optional<torch::Tensor> &cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor> &dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8,
bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(false and "not support low latency");
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
// By default using `ptp128c` FP8 cast
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0);
// Diagnosis tensors
if (cumulative_local_expert_recv_stats.has_value()) {
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
}
if (dispatch_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_topk = static_cast<int>(topk_idx.size(1));
auto num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; // 双buffer操作
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not(async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
auto packed_recv_src_info =
torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (use_fp8) {
if (not use_ue8m0) {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(
packed_recv_x.data_ptr(),
packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(),
mask_buffer_ptr,
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer,
buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(),
topk_idx.data_ptr<int64_t>(),
next_clean_meta.first,
next_clean_meta.second,
num_tokens,
hidden,
num_max_dispatch_tokens_per_rank,
num_topk,
num_experts,
rank,
num_ranks,
use_fp8,
round_scale,
use_ue8m0,
workspace,
num_device_sms,
launch_stream,
phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
}
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
const torch::Tensor &topk_weights, const torch::Tensor &src_info,
const torch::Tensor &layout_range,
const std::optional<torch::Tensor> &combine_wait_recv_cost_stats,
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor> &out) {
EP_HOST_ASSERT(false and "not support low latency");
const std::optional<torch::Tensor>& out) {
#ifndef DISABLE_ROCSHMEM
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
if (combine_wait_recv_cost_stats.has_value()) {
EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
}
auto hidden = static_cast<int>(x.size(2));
auto num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream
auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not(async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate output tensor
torch::Tensor combined_x;
if (out.has_value()) {
EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
combined_x = out.value();
} else {
combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer,
buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(),
topk_idx.data_ptr<int64_t>(),
topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(),
layout_range.data_ptr<int64_t>(),
global_atomic_counter.data_ptr<int>(),
mask_buffer_ptr,
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first,
next_clean_meta.second,
num_combined_tokens,
hidden,
num_max_dispatch_tokens_per_rank,
num_topk,
num_experts,
rank,
num_ranks,
use_logfmt,
workspace,
num_device_sms,
launch_stream,
phases,
zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {combined_x, event, recv_hook};
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
}
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const {
EP_HOST_ASSERT(false and "not support low latency");
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
#ifndef DISABLE_ROCSHMEM
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
// buffer.num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(hip_bfloat16);
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
#else
EP_HOST_ASSERT(false and "ROCSHMEM is disabled during compilation");
return {};
#endif
}
void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks);
internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32);
internode_ll::query_mask_buffer(
mask_buffer_ptr, num_ranks, reinterpret_cast<int*>(mask_status.data_ptr()), at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
void Buffer::low_latency_clean_mask_buffer() {
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled");
internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
} // namespace deep_ep
......@@ -1346,8 +1634,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer)
.def("low_latency_update_mask_buffer", &deep_ep::Buffer::low_latency_update_mask_buffer)
.def("low_latency_query_mask_buffer", &deep_ep::Buffer::low_latency_query_mask_buffer)
.def("low_latency_clean_mask_buffer", &deep_ep::Buffer::low_latency_clean_mask_buffer);
// m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
// m.attr("topk_idx_t") = py::cast(c10::CppTypeToScalarType<deep_ep::topk_idx_t>::value);
// m.attr("int64_t") = py::cast(c10::CppTypeToScalarType<deep_ep::int64_t>::value);
}
......@@ -30,6 +30,11 @@ private:
int64_t num_rdma_bytes;
void *rdma_buffer_ptr = nullptr;
// Shrink mode buffer
bool enable_shrink = false;
int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int num_device_sms;
......@@ -67,11 +72,9 @@ private:
volatile int *moe_recv_rdma_counter = nullptr;
int *moe_recv_rdma_counter_mapped = nullptr;
bool use_default_stream_as_comm_stream = false;
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool use_default_stream_as_comm_stream);
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);
~Buffer() noexcept(false);
......@@ -187,6 +190,12 @@ public:
torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank,
int hidden, int num_experts) const;
void low_latency_update_mask_buffer(int rank_to_mask, bool mask);
void low_latency_query_mask_buffer(const torch::Tensor& mask_status);
void low_latency_clean_mask_buffer();
};
} // namespace deep_ep
......@@ -131,4 +131,46 @@ void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
int* mask_buffer, int* sync_buffer, hipStream_t stream);
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count,
int* global_atomic_counter,
int* mask_buffer, int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* global_atomic_counter,
int* mask_buffer, int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int, int num_combined_tokens,
int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, hipStream_t stream);
void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, hipStream_t stream);
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream);
} // namespace internode_ll
} // namespace deep_ep
......@@ -22,6 +22,8 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
......
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
// #include "ibgda_device.cuh"
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX
#ifndef DISABLE_ROCSHMEM
#include <rocshmem/rocshmem.hpp>
#include <rocshmem/rocshmem_COLL.hpp>
namespace deep_ep {
namespace internode_ll {
template <bool use_warp_sync = false>
__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) {
if (mask_buffer_ptr == nullptr) {
return false;
}
if constexpr (use_warp_sync) {
return shfl_sync(ld_acquire_global(mask_buffer_ptr + rank), 0) != 0;
} else {
return ld_acquire_global(mask_buffer_ptr + rank) != 0;
}
}
__device__ void grid_barrier(int* global_counter, int num_blocks) {
volatile int ret;
__syncthreads();
memory_fence_gpu();
if (threadIdx.x == 0 ) {
ret = atomicAdd((int*)&global_counter[0], 1);
}
__syncthreads();
if (threadIdx.x == 0) {
while (ld_relaxed_global(global_counter) != num_blocks);
}
__syncthreads();
}
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1) {
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
int* mask_buffer_ptr, int* sync_buffer_ptr) {
auto thread_id = static_cast<int>(threadIdx.x);
// Barrier before cleaning (in case of unfinished chunked EP)
// nvshmemx_barrier_all_block();
// // Clean
// auto thread_id = static_cast<int>(threadIdx.x);
// #pragma unroll
// for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
// clean_0[i] = 0;
// #pragma unroll
// for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
// clean_1[i] = 0;
// // Barrier after cleaning (make sure the low-latency mode works fine)
// nvshmemx_barrier_all_block();
if (sync_buffer_ptr == nullptr) {
// rocshmem::rocshmem_barrier_all_wg();
if (thread_id == 0)
rocshmem::rocshmem_barrier_all();
} else {
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT(0);
}
// Clean
#pragma unroll
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
clean_0[i] = 0;
#pragma unroll
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work
if (sync_buffer_ptr == nullptr) {
// rocshmem::rocshmem_barrier_all_wg();
if (thread_id == 0)
rocshmem::rocshmem_barrier_all();
} else {
// barrier<kNumThreads>(thread_id, rank, num_ranks, mask_buffer_ptr, sync_buffer_ptr);
EP_DEVICE_ASSERT(0);
}
}
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream) {
// constexpr int kNumThreads = 256;
// SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
// LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
// clean_0, num_clean_int_0, clean_1, num_clean_int_1);
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
int64_t* clean_1, int num_clean_int_1,
int rank, int num_ranks,
int* mask_buffer_ptr, int* sync_buffer_ptr,
hipStream_t stream) {
constexpr int kNumThreads = 256;
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
clean_0, num_clean_int_0, clean_1, num_clean_int_1,
rank, num_ranks,
mask_buffer_ptr, sync_buffer_ptr);
}
template <bool kUseFP8, bool kUseUE8M0, int kHidden>
__global__ __launch_bounds__(1024, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const topk_idx_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
// const auto sm_id = static_cast<int>(blockIdx.x);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto warp_id = thread_id / 32, lane_id = get_lane_id();
// const auto num_sms = static_cast<int>(gridDim.x);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / num_warps_per_group;
// const auto sub_warp_id = warp_id % num_warps_per_group;
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// // May extract UE8M0 from the scales
// using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
// using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
// EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// // FP8 staffs
// constexpr int kNumPerChannels = 128;
// const int num_scales = kHidden / kNumPerChannels;
// const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
// const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
// // NOTES: currently we have 3 reserved int fields for future use
// using vec_t = std::conditional_t<kUseFP8, int2, int4>;
// const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
// const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
// EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// // Expert counts
// constexpr int kNumMaxWarpGroups = 32;
// __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_DISPATCH_RECV;
// // There are 2 kinds of warps in this part:
// // 1. The first-kind warps for FP8 cast and sending top-k tokens
// // 2. The last warp for reading `topk_idx` and count for per-expert information
// if (warp_id < num_warps - 1) {
// constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
// const auto num_threads = (num_warps - 1) * 32;
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
// for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
// const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
// const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
// const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
// const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// // Overlap top-k index read and source token index writes
// auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
// thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// // FP8 cast
// EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
// #pragma unroll
// for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// // Read
// auto int4_value = __ldg(x_int4 + i);
// if constexpr (kUseFP8) {
// // Calculate local amax
// auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
// float fp32_values[kNumElemsPerRead];
// float amax = kFP8Margin, scale, scale_inv;
// #pragma unroll
// for (int j = 0; j < kNumElemsPerRead; ++ j) {
// fp32_values[j] = static_cast<float>(bf16_values[j]);
// amax = fmaxf(amax, fabsf(fp32_values[j]));
// }
// // Reduce amax and scale
// EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
// amax = warp_reduce_max<16>(amax);
// calculate_fp8_scales(amax, scale, scale_inv, round_scale);
// if (lane_id == 0 or lane_id == 16)
// rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
// // Cast into send buffer
// vec_t int2_value;
// auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerRead; j += 2) {
// float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
// fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
// }
// rdma_x_vec[i] = int2_value;
// } else {
// // Reinterpret-cast is for C++14 compatibility
// rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
// }
// }
// asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
// // Issue IBGDA sends
// if (dst_expert_idx >= 0) {
// int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
// slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
// const auto dst_rank = dst_expert_idx / num_local_experts;
// const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
// const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
// dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// slot_idx * num_bytes_per_msg;
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
// } else {
// // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
// const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
// UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
// }
// // Increase counter after finishing
// __syncwarp();
// lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
// }
// }
// } else if (warp_id == num_warps - 1) {
// EP_DEVICE_ASSERT(num_sms > 1);
// if (sm_id == 0) {
// // The first SM is also responsible for checking QPs
// EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
// // The first SM is also responsible for cleaning the next buffer
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += 32)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// __syncwarp();
// #pragma unroll
// for (int i = lane_id; i < num_experts; i += 32)
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
// }
// // This SM should be responsible for some destination experts, read `topk_idx` for them
// int expert_count[kNumMaxWarpGroups] = {0};
// const auto expert_begin_idx = sm_id * num_warp_groups;
// const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
// // Per lane count
// #pragma unroll 8
// for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
// auto idx = static_cast<int>(__ldg(topk_idx + i));
// if (idx >= expert_begin_idx and idx < expert_end_idx)
// expert_count[idx - expert_begin_idx] ++;
// }
// // Warp reduce
// #pragma unroll
// for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
// auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
// if (lane_id == 0) {
// shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
// atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
// }
// }
// }
// __syncthreads();
// // Issue count sends
// if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
// const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// // Wait local sends issued and send expert counts
// while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
// } else {
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
// }
// // Clean workspace for next use
// atomic_counter_per_expert[responsible_expert_idx] = 0;
// atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// // Clean `packed_recv_count`
// if (dst_rank == 0)
// packed_recv_count[dst_expert_local_idx] = 0;
// }
// __syncwarp();
// // Receiving phase
// LOW_LATENCY_DISPATCH_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
// if (phases & LOW_LATENCY_SEND_PHASE)
// cg::this_grid().sync();
// // Receiving and packing
// if (responsible_expert_idx < num_experts) {
// const auto src_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
// src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
// const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
// const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
// const auto num_aligned_scales = align_up<int>(num_scales, sizeof(float) / sizeof(scale_t));
// const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// // Shared between sub-warps in warp groups
// __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
// // Wait tokens to arrive
// // NOTES: using sub-warp 1 to overlap with sub-warp 0
// int num_recv_tokens, recv_token_begin_idx;
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
// if (sub_warp_id == 1 and lane_id == 0) {
// auto start_time = clock64();
// while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
// auto wait_recv_cost = clock64() - start_time;
// num_recv_tokens = -num_recv_tokens - 1;
// recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
// shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
// shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
// recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
// // Add stats for diagnosis
// if (cumulative_local_expert_recv_stats != nullptr)
// atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
// if (dispatch_wait_recv_cost_stats != nullptr)
// atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost);
// }
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
// num_recv_tokens = shared_num_recv_tokens[warp_group_id];
// recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// // Copy tokens
// EP_DEVICE_ASSERT(num_scales <= 64);
// for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// // Copy source info
// const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
// if (lane_id == 0)
// recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
// __syncwarp();
// // Copy data
// // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
// const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
// const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
// UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// // Copy scales
// if constexpr (kUseFP8) {
// // Equivalent CuTe layout:
// // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
// const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
// const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
// const auto token_idx = recv_token_begin_idx + i;
// const auto token_stride = num_elems_per_pack;
// const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
// if (lane_id < num_scales) {
// const auto pack_idx = lane_id / num_elems_per_pack;
// const auto elem_idx = lane_id % num_elems_per_pack;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
// }
// if (lane_id + 32 < num_scales) {
// const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
// const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
// auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
// recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
// }
// }
// }
// }
__launch_bounds__(1024, 1) __global__ void dispatch(void* packed_recv_x,
void* packed_recv_x_scales,
int* packed_recv_src_info,
int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
int* mask_buffer_ptr,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x,
int64_t* rdma_recv_count,
void* rdma_x,
const void* x,
const int64_t* topk_idx,
int* atomic_counter_per_expert,
int* atomic_finish_counter_per_expert,
int64_t* next_clean,
int num_next_clean_int,
int num_tokens,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
int num_warp_groups,
int num_warps_per_group,
bool round_scale,
int phases) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
// 每个warp处理一个expert
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// May extract UE8M0 from the scales
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = std::conditional_t<kUseFP8, int2, int4>;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// Expert counts
constexpr int kNumMaxWarpGroups = 16; // 每个kernel最多warp group数量,即每个block负责的专家数
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
#ifdef USE_ROCM
// 用于同步
// 16 is the max possible number of warps in AMD GPUs
constexpr int kMaxNumWarps = 1024 / kWarpSize;
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
#endif
// Sending phase,如果没有发送任务,则直接跳到接收阶段
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps - 1) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16); // 128/16 = 8
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerRead) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index writes
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast
EP_STATIC_ASSERT(hidden_bf16_int4 % kWarpSize == 0, "Must use the full warp to reduce");
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
if constexpr (kUseFP8) {
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id % 16 == 0)
rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv;
// Cast into send buffer
vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
}
rdma_x_vec[i] = int2_value;
} else {
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
}
}
__syncthreads();
// Issue IBGDA sends
if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = shfl_sync(slot_idx, 0);
const int dst_rank = dst_expert_idx / num_local_experts;
const int dst_expert_local_idx = dst_expert_idx % num_local_experts;
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx,
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
rocshmem::rocshmem_ctx_quiet(ctx);
#else
rocshmem::rocshmem_schar_put_nbi_wave(
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank);
rocshmem::rocshmem_fence();
#endif
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// Increase counter after finishing
syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
} else if (warp_id == num_warps - 1) {
EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
next_clean[i] = 0;
// Notify before executing `int_p`
syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += kWarpSize)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumMaxWarpGroups] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx]++;
}
// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
}
}
}
__syncthreads();
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
auto dst_ptr = reinterpret_cast<int64_t*>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_ctx_long_atomic_add(ctx, dst_ptr, -num_tokens_sent - 1, dst_rank);
#else
rocshmem::rocshmem_long_atomic_add(dst_ptr, -num_tokens_sent - 1, dst_rank);
#endif
} else {
st_release_sys_global(dst_ptr, -num_tokens_sent - 1);
}
}
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// Clean `packed_recv_count`
if (dst_rank == 0)
packed_recv_count[dst_expert_local_idx] = 0;
}
syncwarp();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
// 如果没有接收直接返回
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if (phases & LOW_LATENCY_SEND_PHASE){
grid_barrier(global_atomic_counter, num_sms);
}
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 =
static_cast<int4*>(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int64_t num_recv_tokens;
int recv_token_begin_idx;
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
if (sub_warp_id == 1 and lane_id == 0) {
auto start_time = wall_clock64();
int64_t wait_recv_cost = 0;
int offset = local_expert_idx * num_ranks + src_rank;
if (not is_rank_masked(mask_buffer_ptr, src_rank)) {
while ((wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES) { // not timeout
if((num_recv_tokens = ld_acquire_global(reinterpret_cast<int64_t*>(
rdma_recv_count + local_expert_idx * num_ranks + src_rank))) != 0) {
break;
}
}
}
// Mask rank if timeout
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank,
local_expert_idx,
src_rank);
if (mask_buffer_ptr == nullptr)
trap();
atomicExch(mask_buffer_ptr + src_rank, 1);
}
// Do not receive tokens if rank timeout or masked
if (num_recv_tokens == 0)
num_recv_tokens = -1;
#if 1
num_recv_tokens = -num_recv_tokens - 1;
int num_recv_tokens_int32 = static_cast<int>(num_recv_tokens);
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens_int32);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens_int32;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens_int32, recv_token_begin_idx);
// Add stats for diagnosis
if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens_int32);
if (dispatch_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<uint64_t*>(dispatch_wait_recv_cost_stats + src_rank), static_cast<uint64_t>(wait_recv_cost));
}
#endif
}
#if 1
#ifdef USE_ROCM
// no needs to reset because there is no iteration
if (lane_id == 0){
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
}
syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group) {}
#else
// asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
#endif
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// Copy tokens
EP_DEVICE_ASSERT(num_scales <= 64);
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// Copy source info
const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
if (lane_id == 0)
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
syncwarp();
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if constexpr (kUseFP8) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if (lane_id < num_scales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + kWarpSize < num_scales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
}
#endif
}
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
}
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void dispatch(void* packed_recv_x,
void* packed_recv_x_scales,
int* packed_recv_src_info,
int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
int* mask_buffer_ptr,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const topk_idx_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases) {
void* rdma_recv_x,
int64_t* rdma_recv_count,
void* rdma_x,
const void* x,
const int64_t* topk_idx,
int64_t* next_clean,
int num_next_clean_int,
int num_tokens,
int hidden,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
bool use_fp8,
bool round_scale,
bool use_ue8m0,
void* workspace,
int num_device_sms,
hipStream_t stream,
int phases) {
constexpr int kNumMaxTopK = 11;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
// const int num_warps_per_group = 32 / num_warp_groups;
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
// EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_sms = ceil_div(num_experts, num_warp_groups);
// EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// // Workspace checks
// auto atomic_counter_per_expert = static_cast<int*>(workspace);
// auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
// EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// // FP8 checks
// if (use_ue8m0)
// EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
// #define DISPATCH_LAUNCH_CASE(hidden) { \
// auto dispatch_func = dispatch<false, false, hidden>; \
// if (use_fp8 and not use_ue8m0) \
// dispatch_func = dispatch<true, false, hidden>; \
// if (use_fp8 and use_ue8m0) \
// dispatch_func = dispatch<true, true, hidden>; \
// LAUNCH_KERNEL(&cfg, dispatch_func, \
// packed_recv_x, packed_recv_x_scales, \
// packed_recv_src_info, packed_recv_layout_range, \
// packed_recv_count, \
// cumulative_local_expert_recv_stats, \
// dispatch_wait_recv_cost_stats, \
// rdma_recv_x, rdma_recv_count, rdma_x, \
// x, topk_idx, \
// atomic_counter_per_expert, atomic_finish_counter_per_expert, \
// next_clean, num_next_clean_int, \
// num_tokens, num_max_dispatch_tokens_per_rank, \
// num_topk, num_experts, rank, num_ranks, \
// num_warp_groups, num_warps_per_group, \
// round_scale, phases); } break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
// SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
// #undef DISPATCH_LAUNCH_CASE
}
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(void* buffer, nv_bfloat162 *shared_amaxmin, const int& lane_id) {
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
// constexpr float kLogThreshold = 0;
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
// constexpr int kNumBits = 10;
// constexpr int kNumValues = 1 << (kNumBits - 1);
// int4 int4_values[kNumSendUnrolls];
// const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
// const auto& bf162_values = reinterpret_cast<nv_bfloat162*>(int4_values);
// // Calculate lane offset
// const auto& ld_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4)));
// const auto& st_buffer = reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(buffer) + lane_id * (kNumSendUnrolls * sizeof(int4) * 10 / 16));
// // Local log amax
// auto bf162_amax = __nv_bfloat162(CUDART_ZERO_BF16, CUDART_ZERO_BF16);
// auto bf162_amin = __nv_bfloat162(CUDART_INF_BF16, CUDART_INF_BF16);
// uint32_t local_signs = 0;
// #pragma unroll
// for (int k = 0; k < kNumSendUnrolls * kNumElemsPerInt4 / 2; ++ k) {
// // TODO: eliminate bank conflicts
// uint32_values[k] = ld_buffer[k];
// local_signs |= ((uint32_values[k] >> 15) & 1) << (k * 2);
// local_signs |= ((uint32_values[k] >> 31) & 1) << (k * 2 + 1);
// uint32_values[k] &= 0x7fff7fff;
// bf162_amax = __hmax2(bf162_amax, bf162_values[k]);
// bf162_amin = __hmin2(bf162_amin, bf162_values[k]);
// }
// // Reduce per 128 channels
// // TODO: figure out how hardware do 2-byte min/max
// auto amax = std::max(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
// auto amin = std::min(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
// constexpr static int kNumLanesToReduce = 128 * sizeof(nv_bfloat16) / (kNumSendUnrolls * sizeof(int4));
// amax = warp_reduce_max<kNumLanesToReduce>(amax);
// amin = warp_reduce_min<kNumLanesToReduce>(amin);
// // Write min/max into the shared memory
// if (shared_amaxmin != nullptr)
// *shared_amaxmin = __nv_bfloat162(amax, amin);
// __syncwarp();
// // Calculate log amin/amax float
// const auto& log_amax = log2f_approx(amax);
// const auto& log_amin = fmaxf(log2f_approx(amin), log_amax - kMinClip);
// const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// // Case into LogFMT-10 if satisfied
// if (enable_cast) {
// const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// const auto step_inv = 1.0f / step;
// const auto rounding = 2.0f - log2f_approx((1.0f + exp2f_approx(step)) * 0.5f) * step_inv;
// const auto fused_rounding = rounding - log_amin * step_inv;
// // Pack every 256 bits into 160 bits
// EP_STATIC_ASSERT(kNumSendUnrolls == 2 or kNumSendUnrolls == 4, "kNumSendUnrolls == 2 or 4 only");
// uint32_t encoded[kNumElemsPerInt4 * 2];
// #pragma unroll 1
// for (int i = 0; i < kNumSendUnrolls / 2; ++ i) {
// #pragma unroll
// for (int k = 0; k < kNumElemsPerInt4; ++ k) {
// const auto& [x, y] = __bfloat1622float2(bf162_values[i * kNumElemsPerInt4 + k]);
// encoded[k * 2 + 0] = __float2uint_rd(fmaxf(log2f_approx(x) * step_inv + fused_rounding, 0));
// encoded[k * 2 + 1] = __float2uint_rd(fmaxf(log2f_approx(y) * step_inv + fused_rounding, 0));
// }
// st_buffer[i * 5 + 0] = (encoded[ 0] >> 0) | (encoded[ 1] << 9) | (encoded[ 2] << 18) | (encoded[ 3] << 27);
// st_buffer[i * 5 + 1] = (encoded[ 3] >> 5) | (encoded[ 4] << 4) | (encoded[ 5] << 13) | (encoded[ 6] << 22) | (encoded[7] << 31);
// st_buffer[i * 5 + 2] = (encoded[ 7] >> 1) | (encoded[ 8] << 8) | (encoded[ 9] << 17) | (encoded[10] << 26);
// st_buffer[i * 5 + 3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
// st_buffer[i * 5 + 4] = (encoded[14] >> 2) | (encoded[15] << 7) | ((i == 0) ? (local_signs << 16) : (local_signs & 0xffff0000u));
// }
// tma_store_fence();
// __syncwarp();
// }
// // Return TMA copy bytes
// return enable_cast ? (32 * (kNumSendUnrolls * sizeof(int4) * 8 * 10 / 16 / 8)):
// (32 * (kNumSendUnrolls * sizeof(int4)));
}
template <int kNumLanes, int kNumSendUnrolls, int kNumRecvUnrolls>
__forceinline__ __device__ void logfmt_check_amaxmin(uint8_t* meta_buffer, float2* shared_log_amax,
float2* shared_log_amin, int* shared_cast_info,
const int lane_id) {
// constexpr float kLogThreshold = 0;
// constexpr float kMinClip = 32; // `== log_2(2 ^ (2 ^ 5))`
// bool enable_cast = true;
// if (lane_id < kNumLanes) {
// // Calculate log amin/amax float
// auto amaxmin2 = reinterpret_cast<uint64_t*>(meta_buffer)[lane_id];
// const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);
// float log_amax[2], log_amin[2];
// #pragma unroll
// for (int i = 0; i < 2; ++ i) {
// auto amax = static_cast<float>(bf162_amaxmin[i].x);
// auto amin = static_cast<float>(bf162_amaxmin[i].y);
// log_amax[i] = log2f_approx(amax);
// log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : fmaxf(log2f_approx(amin), log_amax[i] - kMinClip);
// enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
// }
// shared_log_amax[lane_id] = make_float2(log_amax[0], log_amax[1]);
// shared_log_amin[lane_id] = make_float2(log_amin[0], log_amin[1]);
// }
// const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls): 0u;
// const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));
// if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0)
// shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
// __syncwarp();
}
template <int kNumRecvUnrolls>
__forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float* accum,
const float& log_amax, const float& log_amin,
const bool& enable_cast, const float& weight) {
// if (enable_cast) {
// constexpr int kNumBits = 10;
// constexpr int kNumValues = 1 << (kNumBits - 1);
// const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
// auto decode = [=](const uint32_t &encoded, const uint32_t &sign) {
// const auto decoded = encoded == 0 ? .0f : exp2f_approx((encoded - 1) * step + log_amin);
// return sign ? -decoded : decoded;
// };
// EP_STATIC_ASSERT(kNumRecvUnrolls == 2 or kNumRecvUnrolls == 4, "kNumRecvUnrolls == 2 or 4 only");
// #pragma unroll
// for (int i = 0; i < kNumRecvUnrolls / 2; ++ i) {
// uint32_t concat[6];
// concat[0] = ld_buffer[i * 5];
// #pragma unroll
// for (int k = 1; k < 5; ++ k)
// concat[k] = (ld_buffer[i * 5 + k - 1] >> (32 - k * 5)) | (ld_buffer[i * 5 + k] << (k * 5));
// concat[5] = ld_buffer[i * 5 + 4] >> 7;
// const uint32_t& local_signs = ld_buffer[i * 5 + 4] >> 16;
// #pragma unroll
// for (int k = 0; k < 5; ++ k) {
// accum[i * 16 + k * 3 + 0] += decode((concat[k] >> 0) & 0x1ff, (local_signs >> (k * 3 + 0)) & 1) * weight;
// accum[i * 16 + k * 3 + 1] += decode((concat[k] >> 9) & 0x1ff, (local_signs >> (k * 3 + 1)) & 1) * weight;
// accum[i * 16 + k * 3 + 2] += decode((concat[k] >> 18) & 0x1ff, (local_signs >> (k * 3 + 2)) & 1) * weight;
// }
// accum[i * 16 + 15] += decode(concat[5] & 0x1ff, (local_signs >> 15) & 1) * weight;
// }
// } else {
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto bf16_pack = *reinterpret_cast<__nv_bfloat162*>(ld_buffer + k);
// accum[k * 2 + 0] += static_cast<float>(bf16_pack.x) * weight;
// accum[k * 2 + 1] += static_cast<float>(bf16_pack.y) * weight;
// }
// }
const int num_warp_groups = DIVUP(num_experts, num_device_sms);
EP_HOST_ASSERT(num_warp_groups <= 16);
const int num_warps_per_group = 16 / num_warp_groups; // 每个kernel最大16个warp
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = DIVUP(num_experts, num_warp_groups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks
auto atomic_counter_per_expert = static_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<false, false, hidden>; \
if(use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
if(use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, \
dispatch_func, \
packed_recv_x, \
packed_recv_x_scales, \
packed_recv_src_info, \
packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
mask_buffer_ptr, \
cumulative_local_expert_recv_stats, \
dispatch_wait_recv_cost_stats, \
rdma_recv_x, \
rdma_recv_count, \
rdma_x, \
x, \
topk_idx, \
atomic_counter_per_expert, \
atomic_finish_counter_per_expert, \
next_clean, \
num_next_clean_int, \
num_tokens, \
num_max_dispatch_tokens_per_rank, \
num_topk, \
num_experts, \
rank, \
num_ranks, \
num_warp_groups, \
num_warps_per_group, \
round_scale, \
phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
// const auto sm_id = __shfl_sync(0xffffffff, static_cast<int>(blockIdx.x), 0);
// const auto num_sms = __shfl_sync(0xffffffff, static_cast<int>(gridDim.x), 0);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto num_threads = __shfl_sync(0xffffffff, static_cast<int>(blockDim.x), 0);
// const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / num_warps_per_group;
// const auto sub_warp_id = warp_id % num_warps_per_group;
// const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// extern __shared__ __align__(1024) uint8_t smem_buffer[];
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
// constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Use different unroll factors for send and recv phases
// constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2;
// constexpr int kNumRecvUnrolls = 2;
// constexpr int hidden_bf16_int4_pad = align_up(static_cast<int>(hidden_bf16_int4), 32 * kNumSendUnrolls);
// EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls");
// EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden");
// EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
// // Message package
// EP_STATIC_ASSERT(kHidden % 128 == 0, "Invalid hidden");
// constexpr int kNumDivisions = kHidden / 128;
// constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162);
// constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes;
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += 32)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// __syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = static_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // TMA stuffs
// constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
// constexpr int kNumStages = 3;
// constexpr int kNumPrefetch = 1;
// EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages");
// auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes);
// uint32_t tma_phase = 0;
// auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * (kNumTMABufferBytes + 16)); });
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); });
// auto meta_buffers = kUseLogFMT ? reinterpret_cast<nv_bfloat162*>(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr;
// EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, "TMA buffer size exceed limit");
// // Initialize m-barriers
// if (lane_id < kNumStages) {
// mbarrier_init(full_barriers[lane_id], 1);
// fence_barrier_init();
// }
// __syncwarp();
// constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls);
// auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
// tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes);
// };
// auto get_num_tma_bytes = [&](const int& offset_int4) {
// return min(kNumTMABufferBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
// };
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
// const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// int num_send_bytes = hidden * sizeof(nv_bfloat16);
// if (not zero_copy or dst_p2p_ptr != 0) {
// // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
// const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
// const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// // Prefetch
// if (elect_one_sync())
// tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0));
// __syncwarp();
// int tma_offset_bytes = kNumMetaBytes;
// #pragma unroll
// for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++ iter_idx) {
// // Load the next iteration
// const int& stage_idx = iter_idx % kNumStages;
// const int& next_stage_idx = (iter_idx + 1) % kNumStages;
// if (iter_idx + 1 < kNumIters and elect_one_sync()) {
// tma_store_wait<kNumStages - kNumPrefetch - 1>();
// const auto& offset_int4 = i + 32 * kNumSendUnrolls;
// tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4));
// }
// __syncwarp();
// // Wait the current TMA arrival
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// if constexpr (kUseLogFMT) {
// // Cast if possible
// constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;
// int num_tma_bytes = logfmt_encode<kNumSendUnrolls>(
// tma_buffers[stage_idx],
// // NOTES: only the leader lane will write the result
// (i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr,
// lane_id);
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes);
// tma_offset_bytes += num_tma_bytes;
// } else {
// // BF16 original values
// if (elect_one_sync())
// tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i));
// }
// __syncwarp();
// }
// // Store metadata (min/max values) for LogFMT
// if constexpr (kUseLogFMT) {
// num_send_bytes = tma_offset_bytes;
// if (elect_one_sync())
// tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes);
// }
// // Flush all stores
// tma_store_wait<0>();
// __syncwarp();
// }
// // Issue RDMA
// // NOTES: for zero-copy mode, we assume the data is already in the send buffer
// if (dst_p2p_ptr == 0)
// nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset);
// }
// // Put the finishing flag
// EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
// auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
// if (dst_p2p_ptr == 0) {
// nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
// } else {
// st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// __syncwarp();
// // Destroy m-barriers
// if (lane_id < kNumStages) {
// mbarrier_inval(full_barriers[lane_id]);
// fence_barrier_init();
// }
// __syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive
// if (responsible_expert_idx < num_experts) {
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
// if (sub_warp_id == 0 and lane_id == 0) {
// auto start_time = clock64();
// while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
// auto wait_recv_cost = clock64() - start_time;
// if (combine_wait_recv_cost_stats != nullptr) {
// const auto& src_rank = responsible_expert_idx / num_local_experts;
// atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
// }
// }
// }
// cg::this_grid().sync();
// // Reassign warp groups
// constexpr int kMaxNumGroups = 2;
// const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32);
// const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1));
// const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0);
// const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0);
// EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
// EP_DEVICE_ASSERT(num_topk <= 32);
// EP_DEVICE_ASSERT(num_groups > 0);
// if (group_idx < num_groups) {
// constexpr int kNumStages = 3;
// constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2;
// constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2;
// constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10;
// constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t);
// constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3;
// // Reallocate shared memory
// const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx;
// auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes); });
// auto empty_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint64_t*>(smem_group_buffer + i * kNumTMABufferBytes + 8); });
// auto tma_ld_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint8_t* >(smem_group_buffer + i * kNumTMABufferBytes + 16); });
// auto tma_st_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<uint32_t*>(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes); });
// // Redundant when logfmt is disabled
// const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2;
// auto log_amax_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + i * kNumDivisionBytes); });
// auto log_amin_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes); });
// auto cast_info_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int*> (smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes); });
// uint32_t tma_phase = 0;
// EP_STATIC_ASSERT(kNumStages < 32, "Too many stages");
// if (decode_warp_idx == num_decode_warps)
// tma_phase = (1 << kNumStages) - 1;
// // Initialize m-barriers
// if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) {
// mbarrier_init(full_barriers[lane_id], 1);
// mbarrier_init(empty_barriers[lane_id], num_decode_warps);
// }
// asm volatile("bar.sync %0, %1;" :: "r"(group_idx + 1), "r"((num_decode_warps + 1) * 32));
// int stage_idx = 0, topk_idx_by_lane = 0;
// EP_STATIC_ASSERT(kNumMaxTopk <= 32, "Invalid number of topks");
// if (decode_warp_idx == num_decode_warps) {
// // TMA load warp
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk)
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// for (int i = 0; i < num_topk; ++ i) {
// int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i);
// if (topk_idx_reg < 0)
// continue;
// mbarrier_wait<true>(empty_barriers[stage_idx], tma_phase, stage_idx);
// auto buffer = static_cast<uint8_t*>(rdma_recv_x) + (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot;
// if constexpr (kUseLogFMT) {
// logfmt_check_amaxmin<kNumDivisions / 2, kNumSendUnrolls, kNumRecvUnrolls>(
// buffer, reinterpret_cast<float2*>(log_amax_buffers[stage_idx]),
// reinterpret_cast<float2*>(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id);
// }
// if (elect_one_sync()) {
// int num_casted = 0;
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1];
// num_casted = (info >> 1) + (info & 1);
// }
// int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes;
// tma_load_1d(tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes);
// mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes);
// }
// __syncwarp();
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// }
// } else {
// // Reduction warps
// float topk_weights_by_lane;
// for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) {
// if (lane_id < num_topk) {
// topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
// topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
// }
// __syncwarp();
// float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
// for (int i = 0; i < num_topk; ++ i) {
// if (__shfl_sync(0xffffffff, topk_idx_by_lane, i) < 0)
// continue;
// const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i);
// mbarrier_wait<true>(full_barriers[stage_idx], tma_phase, stage_idx);
// if constexpr (kUseLogFMT) {
// const auto& info = cast_info_buffers[stage_idx][decode_warp_idx];
// bool enable_cast = info & 1;
// int num_casted_prefix = info >> 1;
// int tma_offset = kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix);
// int division_idx = decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id),
// combined_values, log_amax_buffers[stage_idx][division_idx], log_amin_buffers[stage_idx][division_idx], enable_cast, topk_weight
// );
// } else {
// int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx;
// decode_and_accumulate<kNumRecvUnrolls>(
// reinterpret_cast<uint32_t*>(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id),
// combined_values, 0, 0, false, topk_weight
// );
// }
// if (elect_one_sync())
// mbarrier_arrive(empty_barriers[stage_idx]);
// stage_idx = (stage_idx + 1) % kNumStages;
// }
// tma_store_wait<0>();
// #pragma unroll
// for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) {
// auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]);
// tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast<uint32_t*>(&combined_pack);
// }
// tma_store_fence();
// if (elect_one_sync()) {
// tma_store_1d(tma_st_buffers[decode_warp_idx],
// static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32,
// kNumBF16PerWarpBytes);
// }
// __syncwarp();
// }
// }
// }
__launch_bounds__(1024, 1) __global__ void combine(void* combined_x,
void* rdma_recv_x,
int* rdma_recv_flag,
void* rdma_send_x,
const void* x,
const int64_t* topk_idx,
const float* topk_weights,
const int* src_info,
const int64_t* layout_range,
int* global_atomic_counter,
int* mask_buffer_ptr,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean,
int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens,
int hidden,
int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts,
int rank,
int num_ranks,
int num_warp_groups,
int num_warps_per_group,
int phases,
bool zero_copy) {
#if !defined(ROCM_DISABLE_CTX)
__shared__ rocshmem::rocshmem_ctx_t ctx;
rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
// const auto sm_id = static_cast<int>(blockIdx.x);
// const auto num_sms = static_cast<int>(gridDim.x);
// const auto thread_id = static_cast<int>(threadIdx.x);
// const auto num_threads = static_cast<int>(blockDim.x);
// const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
// const auto num_local_experts = num_experts / num_ranks;
// const auto warp_group_id = warp_id / kNumWarpsPerGroup;
// const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
// const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// // Data type staffs
// constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(gpu_bfloat16_t);
// const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// // Message package
// // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
// constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t);
// EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// __syncthreads();
// #ifdef USE_ROCM
// // 16 is the max possible number of warps in AMD GPUs
// constexpr int kMaxNumWarps = 1024 / kWarpSize;
// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
// if (threadIdx.x==0){
// // printf("combine");
// #pragma unroll
// for (int i = 0; i < kMaxNumWarps; ++i) {
// sync_large_warp_counters[i] = 0;
// }
// }
// __syncthreads();
// #endif
// // Sending phase
// if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
// goto LOW_LATENCY_COMBINE_RECV;
// // Clean up next buffer
// if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
// #pragma unroll
// for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
// next_clean[i] = 0;
// // Notify before executing `int_p`
// syncwarp();
// if (lane_id == 0)
// atomic_add_release_global(atomic_clean_flag, num_experts);
// }
// // Issue IBGDA sends
// if (responsible_expert_idx < num_experts) {
// const auto dst_rank = responsible_expert_idx / num_local_experts;
// const auto local_expert_idx = responsible_expert_idx % num_local_experts;
// const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
// const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
// const auto local_x = reinterpret_cast<const int4*>(x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
// const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
// const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
// local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// // Unpack layout
// int offset, num_tokens_to_send;
// unpack2(layout, num_tokens_to_send, offset);
// // Issue IBGDA send
// for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
// const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
// const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
// const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// // Copy directly to local rank, or copy to buffer and issue RDMA
// auto src_idx = __ldg(local_src_info + token_idx);
// const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
// const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
// if (dst_rank == rank) {
// const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
// } else {
// const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
// if (not zero_copy)
// UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmemx_int8_put_nbi_warp(
// #else
// internode::shmem_ctx_schar_put_nbi_warp(ctx,
// #endif
// reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_fence();
// #else
// internode::shmem_ctx_quiet(ctx);
// #endif
// }
// }
// // Put finishing flag
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
// #ifdef USE_ROCM
// if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(
// &sync_large_warp_counters[warp_group_id], 1,
// __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
// }
// syncwarp();
// while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
// #else
// asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// #endif
// if (sub_warp_id == 1 and lane_id == 0) {
// while (ld_acquire_global(atomic_clean_flag) == 0);
// if (dst_rank != rank) {
// #ifdef USE_ROCM
// #if defined(ROCM_DISABLE_CTX)
// internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #else
// internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
// #endif
// #else
// nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
// #endif
// } else {
// st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
// }
// atomic_add_release_global(atomic_clean_flag, -1);
// }
// syncwarp();
// }
// // Receiving phase
// LOW_LATENCY_COMBINE_RECV:
// if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
// return;
// // Wait all ranks to arrive and notify PCIe usage
// if (responsible_expert_idx < num_experts) {
// EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
// if (sub_warp_id == 0 and lane_id == 0){
// while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
// }
// }
// grid_barrier(global_atomic_counter, num_sms);
// // Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
// if (thread_id < hidden_bf16_int4) {
// for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// // Read top-k indices and weights
// int reg_topk_idx[kNumMaxTopk];
// float reg_topk_weights[kNumMaxTopk];
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) {
// reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
// reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
// }
// float combined_values[kNumElemsPerInt4] = {0.0f};
// #pragma unroll
// for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// // Read from sources
// auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
// auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// // Reduce
// auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
// const auto x_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&x_vec);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
// }
// // Write results
// int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
// auto combined_bf16 = reinterpret_cast<gpu_bfloat16_t*>(&combined_values);
// #pragma unroll
// for (int j = 0; j < kNumElemsPerInt4; ++ j)
// combined_bf16[j] = static_cast<gpu_bfloat16_t>(combined_values[j]);
// (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
// }
// }
#if !defined(ROCM_DISABLE_CTX)
rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
}
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const topk_idx_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
void* rdma_recv_x,
int64_t* rdma_recv_flag,
void* rdma_send_x,
const void* x,
const int64_t* topk_idx,
const float* topk_weights,
const int* src_info,
const int64_t* layout_range,
int* global_atomic_counter,
int* mask_buffer_ptr,
int64_t* combine_wait_recv_cost_stats,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int64_t* next_clean,
int num_next_clean_int,
int num_combined_tokens,
int hidden,
int num_max_dispatch_tokens_per_rank,
int num_topk,
int num_experts,
int rank,
int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy) {
void* workspace,
int num_device_sms,
hipStream_t stream,
int phases,
bool zero_copy) {
constexpr int kNumMaxTopk = 11;
// const int num_warp_groups = ceil_div(num_experts, num_device_sms);
// const int num_warps_per_group = 32 / num_warp_groups;
// const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
// EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
// const auto num_warps = num_warp_groups * num_warps_per_group;
// const auto num_sms = max(ceil_div(num_experts, num_warp_groups),
// num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
// // Check workspace
// auto atomic_clean_flag = static_cast<int*>(workspace);
// EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
// EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// // Online cast cannot use zero-copy
// EP_HOST_ASSERT(not (zero_copy and use_logfmt));
// constexpr int kNumStages = 3;
// constexpr int kNumMaxUnrolls = 4;
// constexpr int kMaxNumGroups = 2;
// // Send buffer size
// const int num_meta_bytes = hidden / 128 * 4;
// const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
// const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
// // Receive buffer size
// const int num_recv_tma_bytes = 16 + hidden * 2;
// const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// // Total requirement
// const int smem_size = max(smem_send_size, smem_recv_size);
// #define COMBINE_LAUNCH_CASE(hidden) { \
// auto combine_func = use_logfmt ? \
// combine<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
// combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// SET_SHARED_MEMORY_FOR_TMA(combine_func); \
// LAUNCH_KERNEL(&cfg, combine_func, \
// combined_x, \
// rdma_recv_x, rdma_recv_flag, rdma_send_x, \
// x, topk_idx, topk_weights, src_info, layout_range, \
// combine_wait_recv_cost_stats, \
// next_clean, num_next_clean_int, \
// atomic_clean_flag, \
// num_combined_tokens, hidden, num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, rank, num_ranks, \
// num_warp_groups, num_warps_per_group, \
// phases, zero_copy); } break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
const int num_warp_groups = DIVUP(num_experts, num_device_sms);
const int num_warps_per_group = 16 / num_warp_groups;
const int num_recv_per_sm = DIVUP(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = max(DIVUP(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : DIVUP(num_combined_tokens, num_recv_per_sm));
// Check workspace
auto atomic_clean_flag = static_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
// Online cast cannot use zero-copy
EP_HOST_ASSERT(not(zero_copy and use_logfmt));
EP_HOST_ASSERT(use_logfmt == 0);
constexpr int kNumMaxUnrolls = 4;
#ifdef USEING_TMA
constexpr int kNumStages = 3;
constexpr int kMaxNumGroups = 2;
// Send buffer size
const int num_meta_bytes = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL * 4;
const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16;
const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes);
// Receive buffer size
const int num_recv_tma_bytes = 16 + hidden * 2;
const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3);
// Total requirement
const int smem_size = max(smem_send_size, smem_recv_size);
#endif
// #define COMBINE_LAUNCH_CASE(hidden) \
// { \
// auto combine_func = combine<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
// LAUNCH_KERNEL(&cfg, \
// combine_func, \
// combined_x, \
// rdma_recv_x, \
// rdma_recv_flag, \
// rdma_send_x, \
// x, \
// topk_idx, \
// topk_weights, \
// src_info, \
// layout_range, \
// global_atomic_counter, \
// mask_buffer_ptr, \
// combine_wait_recv_cost_stats, \
// next_clean, \
// num_next_clean_int, \
// atomic_clean_flag, \
// num_combined_tokens, \
// hidden, \
// num_topk, \
// num_max_dispatch_tokens_per_rank, \
// num_experts, \
// rank, \
// num_ranks, \
// num_warp_groups, \
// num_warps_per_group, \
// phases, \
// zero_copy); \
// } \
// break
// SETUP_LAUNCH_CONFIG(num_sms, num_warps* kWarpSize, stream);
// SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
// #undef COMBINE_LAUNCH_CASE
}
template <int kNumThreads>
__launch_bounds__(kNumThreads, 1) __global__ void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_threads = num_sms * kNumThreads;
const auto thread_id = sm_id * kNumThreads + static_cast<int>(threadIdx.x);
for (int rank_id = thread_id; rank_id < num_ranks; rank_id += num_threads) {
mask_tensor[rank_id] = mask_buffer_ptr[rank_id];
}
}
void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor, hipStream_t stream) {
constexpr int num_sms = 1;
constexpr int kNumThreads = 1024;
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, query_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks, mask_tensor);
}
template <int kNumThreads>
__launch_bounds__(kNumThreads, 1) __global__ void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
if (sm_id == 0 && thread_id == 0) {
atomicExch(mask_buffer_ptr + rank_to_mask, mask ? 1 : 0);
}
}
void update_mask_buffer(int* mask_buffer_ptr, int rank, bool mask, hipStream_t stream) {
constexpr int num_sms = 1;
constexpr int kNumThreads = 64;
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, update_mask_buffer<kNumThreads>, mask_buffer_ptr, rank, mask);
}
template <int kNumThreads>
__launch_bounds__(kNumThreads, 1) __global__ void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks) {
auto thread_id = static_cast<int>(threadIdx.x);
#pragma unroll
for (int i = thread_id; i < num_ranks; i += kNumThreads)
mask_buffer_ptr[i] = 0;
}
void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, hipStream_t stream) {
constexpr int num_sms = 1;
constexpr int kNumThreads = 64;
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_mask_buffer<kNumThreads>, mask_buffer_ptr, num_ranks);
}
} // namespace internode_ll
} // namespace deep_ep
#endif
......@@ -125,6 +125,10 @@ __device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
__device__ __forceinline__ void st_release_sys_global(const int64_t *ptr, int64_t val) {
__hip_atomic_store(const_cast<int64_t *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
__hip_atomic_store(const_cast<int *>(ptr), val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_WORKGROUP);
}
......@@ -157,6 +161,12 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
return ret;
}
__device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
int64_t ret;
ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
return ret;
}
__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
int ret;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE,
......@@ -165,6 +175,12 @@ __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int val
return ret;
}
__device__ __forceinline__ int ld_relaxed_global(const int *ptr) {
int ret;
ret = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
return ret;
}
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
int ret;
ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP);
......@@ -245,6 +261,11 @@ __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val)
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
int64_t *non_const_ptr = const_cast<int64_t *>(ptr);
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
......@@ -279,6 +300,22 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_s
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
dtype_b_t packed;
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return packed;
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
x = unpacked_ptr[0], y = unpacked_ptr[1];
}
template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
......@@ -290,15 +327,47 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values);
}
__forceinline__ __device__ int warp_reduce_sum(int value) {
if constexpr (kWarpSize == 64)
value += shfl_xor<int>(value, 32);
value += shfl_xor<int>(value, 16);
value += shfl_xor<int>(value, 8);
value += shfl_xor<int>(value, 4);
value += shfl_xor<int>(value, 2);
value += shfl_xor<int>(value, 1);
return value;
#ifdef USE_ROCM
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#else
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#endif
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
}
__forceinline__ __device__ int get_lane_id() {
......@@ -340,4 +409,95 @@ __forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs, int ran
}
__syncthreads();
}
// Operation functors
template <typename T>
struct ReduceSum {
__device__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct ReduceMax {
__device__ T operator()(T a, T b) const { return a > b ? a : b; }
};
template <typename T>
struct ReduceMin {
__device__ T operator()(T a, T b) const { return a < b ? a : b; }
};
template <typename T>
struct ReduceAnd {
__device__ T operator()(T a, T b) const { return a & b; }
};
template <typename T>
struct ReduceOr {
__device__ T operator()(T a, T b) const { return a | b; }
};
// Unified reduction function
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
EP_STATIC_ASSERT(kNumLanesPerGroup == kWarpSize or kNumLanesPerGroup == 32 or
kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or kNumLanesPerGroup == 4 or
kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1)
value = op(value, shfl_xor(value, 1));
if constexpr (kNumLanesPerGroup <= 2)
value = op(value, shfl_xor(value, 2));
if constexpr (kNumLanesPerGroup <= 4)
value = op(value, shfl_xor(value, 4));
if constexpr (kNumLanesPerGroup <= 8)
value = op(value, shfl_xor(value, 8));
if constexpr (kNumLanesPerGroup <= 16)
value = op(value, shfl_xor(value, 16));
if constexpr(kWarpSize == 64){
if constexpr (kNumLanesPerGroup <= 32)
value = op(value, shfl_xor(value, 32));
}
} else {
if constexpr(kWarpSize == 64){
if constexpr (kNumLanesPerGroup >= kWarpSize)
value = op(value, shfl_xor(value, 32));
}
if constexpr (kNumLanesPerGroup >= 32)
value = op(value, shfl_xor(value, 16));
if constexpr (kNumLanesPerGroup >= 16)
value = op(value, shfl_xor(value, 8));
if constexpr (kNumLanesPerGroup >= 8)
value = op(value, shfl_xor(value, 4));
if constexpr (kNumLanesPerGroup >= 4)
value = op(value, shfl_xor(value, 2));
if constexpr (kNumLanesPerGroup >= 2)
value = op(value, shfl_xor(value, 1));
}
return value;
}
// Convenience aliases
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMax<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMin<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_and(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceAnd<T>{});
}
template <int kNumLanesPerGroup = kWarpSize, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_or(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
}
} // namespace deep_ep
......@@ -39,7 +39,7 @@ class Buffer:
allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False,
explicitly_destroy: bool = False,
use_default_stream_as_comm_stream: bool = True,
enable_shrink: bool = False,
) -> None:
"""
Initialize the communication buffer.
......@@ -59,6 +59,7 @@ class Buffer:
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
"""
check_nvlink_connections(group)
......@@ -70,6 +71,7 @@ class Buffer:
self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink
self.runtime = deep_ep_cpp.Buffer(
self.rank,
self.group_size,
......@@ -77,7 +79,7 @@ class Buffer:
num_rdma_bytes,
low_latency_mode,
explicitly_destroy,
use_default_stream_as_comm_stream,
enable_shrink
)
# Synchronize device IDs
......@@ -989,3 +991,31 @@ class Buffer:
return self.runtime.get_next_low_latency_combine_buffer(
num_max_dispatch_tokens_per_rank, hidden, num_experts
)
def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False):
"""
Mask (unmask) a rank during communication (dispatch, combine, and clean)
Arguments:
rank: the rank to mask (unmask).
mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank.
"""
self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask)
def low_latency_query_mask_buffer(self, mask_status: torch.Tensor):
"""
Query the mask status of all ranks
Arguments:
mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means mask and `0` means unmasked.
"""
self.runtime.low_latency_query_mask_buffer(mask_status)
def low_latency_clean_mask_buffer(self):
"""
Clean the mask buffer
"""
self.runtime.low_latency_clean_mask_buffer()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment