Commit 0395bf27 authored by lishen's avatar lishen
Browse files

nmz上normal-dispatch优化

parent a42ecdc0
......@@ -21,7 +21,7 @@ namespace internode {
extern shmem_team_t cpu_rdma_team;
struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits;
int src_rdma_rank, is_token_in_nvl_rank_bits; // sizeof(SourceMeta) = 8
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
......@@ -619,47 +619,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
}
EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);
//////////////// Copy data into the send buffer ////////////////
// Copy the source metadata into the symmetric send buffer
if(lane_id < num_topk_ranks) {
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
}
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
}
// Copy `x` into the symmetric send buffer
auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
}
};
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
}
// Copy the source metadata into the symmetric send buffer
if(lane_id < num_topk_ranks) {
st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
}
// Copy `x_scales` into the symmetric send buffer
#pragma unroll
for(int i = lane_id; i < num_scales; i += kWarpSize) {
auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
// auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
// auto value = ld_nc_global(x_scales + offset);
#pragma unroll
for(int j = 0; j < num_topk_ranks; ++j) {
st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
}
}
#pragma unroll
for(int i = 0; i < num_topk_ranks; ++i) {
dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
}
// Copy `topk_idx` and `topk_weights` into the symmetric send buffer
#pragma unroll
for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
auto rank_idx = i / num_topk, copy_idx = i % num_topk;
auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
......@@ -899,7 +892,7 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
// Read SourceMeta first; it corresponds to the remote data write done by kRDMASender in kRDMASenderCoordinator
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted)));
if(lane_id == src_rdma_rank) {
num_tokens_to_recv_from_rdma -= 1;
}
......@@ -918,37 +911,40 @@ dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv
// Acquire a free slot
int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
// Set up the src and dst locations
auto src_gpu_buffer_x = reinterpret_cast<int4*>(reinterpret_cast<int8_t*>(shifted) + sizeof(SourceMeta));
auto src_gpu_buffer_scales = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_x) + hidden_bytes);
auto src_gpu_buffer_topk_idx = reinterpret_cast<int*>(reinterpret_cast<int8_t*>(src_gpu_buffer_scales) + num_scales * sizeof(float));
auto src_gpu_buffer_topk_weights = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_topk_idx) + num_topk * sizeof(int));
// Copy the source metadata
if(lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
auto dst_gpu_buffer_x = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
auto dst_gpu_buffer_scales = nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales;
auto dst_gpu_buffer_topk_idx = nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk;
auto dst_gpu_buffer_topk_weights = nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk;
// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<float*>(shifted) + num_scales;
if(lane_id == 0) {
st_na_global(reinterpret_cast<int64_t*>(nvl_channel_src_meta.buffer() + dst_slot_idx),
*reinterpret_cast<int64_t*>(&src_meta));
}
// Copy `topk_idx` and `topk_weights`
if(lane_id < num_topk) {
// Load
auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
shifted = reinterpret_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
// Convert and store
idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
weight_value = idx_value >= 0 ? weight_value : 0.0f;
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
dst_gpu_buffer_x,
src_gpu_buffer_x,
ld_direct_global, st_na_global);
UNROLLED_WARP_COPY(1, lane_id, num_scales,
dst_gpu_buffer_scales,
src_gpu_buffer_scales,
ld_direct_global, st_na_global);
for(int t = lane_id; t < num_topk; t += kWarpSize) {
int idx_val = ld_direct_global(reinterpret_cast<int*>(src_gpu_buffer_topk_idx) + t);
float w_val = ld_direct_global(reinterpret_cast<float*>(src_gpu_buffer_topk_weights) + t);
int new_idx = (idx_val >= dst_rank_expert_begin && idx_val < dst_rank_expert_end)
? (idx_val - dst_rank_expert_begin) : -1;
float new_w = (new_idx != -1) ? w_val : 0.0f;
dst_gpu_buffer_topk_idx[t] = new_idx;
dst_gpu_buffer_topk_weights[t] = new_w;
}
// Stop early when the NVL buffer does not have enough space
......
......@@ -54,6 +54,7 @@
}
// HELPER FUNCTIONS
// #####################################################################################
// Qualifier shorthand that expands to `__device__ inline` plus the `always_inline` attribute.
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
template <typename T>
__device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWarpSize,
......@@ -118,7 +119,6 @@ __device__ __forceinline__ void trap() {
}
// Thin wrapper that issues `__threadfence_system()` on behalf of the calling thread.
__device__ __forceinline__ void memory_fence() { __threadfence_system(); }
......@@ -151,11 +151,13 @@ __device__ __forceinline__ int ld_relaxed_sys_global(const int *ptr) {
ret = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
return ret;
}
// Relaxed, system-scope atomic load of a 64-bit unsigned value.
//
// ptr:     address to load from.
// returns: the full 64-bit value read from `ptr`.
//
// The return type matches the loaded width: returning `int` here would silently
// discard the upper 32 bits of the value. Callers that assign the result to a
// narrower integer still compile and keep their previous (truncating) behaviour.
__device__ __forceinline__ uint64_t ld_relaxed_sys_global(const uint64_t *ptr) {
    return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
}
__device__ __forceinline__ int ld_relaxed_sys_global(const int64_t *ptr) {
int64_t ret;
ret = __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
......@@ -180,7 +182,6 @@ __device__ __forceinline__ int64_t ld_acquire_sys_global(const int64_t *ptr) {
return ret;
}
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
int ret;
ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
......@@ -269,12 +270,22 @@ __device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
// Relaxed, agent-scope atomic store of a `float`.
// The pointer is declared `const` in the signature, so constness is cast away
// before the store is issued.
__device__ __forceinline__ void st_na_relaxed(const float *ptr, float val) {
    __hip_atomic_store(const_cast<float *>(ptr), val,
                       __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
// Relaxed, agent-scope atomic store of a 64-bit signed integer.
// The pointer is declared `const` in the signature, so constness is cast away
// before the store is issued.
__device__ __forceinline__ void st_na_relaxed(const int64_t *ptr, int64_t val) {
    __hip_atomic_store(const_cast<int64_t *>(ptr), val,
                       __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
// Stores the four 32-bit components of `val` into `*ptr`, each with a relaxed,
// agent-scope atomic store, in x, y, z, w order.
//
// ptr: destination; declared `const` in the signature, so constness is cast away.
// val: value whose components are written.
//
// Every component is written exactly once. Preceding each atomic store with a
// plain assignment to the same address would only duplicate the write and add an
// unsynchronised access, so no plain assignments are performed.
// Note: the 16-byte value is written as four independent stores, not as a single
// atomic unit.
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
    int4 *dst = const_cast<int4 *>(ptr);
    __hip_atomic_store(&(dst->x), val.x, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&(dst->y), val.y, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&(dst->z), val.z, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&(dst->w), val.w, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
......@@ -297,6 +308,14 @@ __device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) {
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// Writes `val` to `*ptr` one 32-bit component at a time (x, y, z, w), each via an
// agent-scope atomic store with release ordering.
// The destination is declared `const` in the signature, so constness is cast away.
// Note: these are four separate stores; the 16-byte value is not written as one
// atomic unit.
__device__ __forceinline__ void st_na_release(const int4 *ptr, int4 val) {
    int4 *dst = const_cast<int4 *>(ptr);
    __hip_atomic_store(&dst->x, val.x, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&dst->y, val.y, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&dst->z, val.z, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    __hip_atomic_store(&dst->w, val.w, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
// TODO:: apply "st.global.L1::no_allocate" in ROCM
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t &value) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment