Commit 8a0688f3 authored by lishen
Browse files

简化FORCE_NVSHMEM_API宏定义的数量

parent 6b49c021
...@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual); const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx, internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
...@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#endif #endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank); num_bytes_per_msg, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#if defined(FORCE_NVSHMEM_API) #if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank); void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base))); int *rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1); st_na_release(rptr_actual, -num_tokens_sent - 1);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx, internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_long_atomic_add( internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
} }
...@@ -648,8 +647,9 @@ combine(void* combined_x, ...@@ -648,8 +647,9 @@ combine(void* combined_x,
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base)); char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual); const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx, internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
...@@ -657,9 +657,7 @@ combine(void* combined_x, ...@@ -657,9 +657,7 @@ combine(void* combined_x,
#endif #endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank); hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} }
} }
...@@ -677,19 +675,19 @@ combine(void* combined_x, ...@@ -677,19 +675,19 @@ combine(void* combined_x,
#if defined(FORCE_NVSHMEM_API) #if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank); void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base))); int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1); st_na_release(req_rptr_actual, 1);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx, internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_long_atomic_add( internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_flag + global_expert_idx, 1, dst_rank); rdma_recv_flag + global_expert_idx, 1, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1); st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
} }
...@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV:
#pragma unroll #pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) { for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources // Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4); auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// Reduce // Reduce
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment